This commit is contained in:
parent
f848d72f06
commit
7fabe7f5d3
42
README.md
42
README.md
@ -17,25 +17,37 @@ Bellande training framework in Rust for machine learning models
|
|||||||
|
|
||||||
## Example Usage
|
## Example Usage
|
||||||
```rust
|
```rust
|
||||||
use bellande_ai_training_framework::prelude::*;
|
use bellande_artificial_intelligence_training_framework::{
|
||||||
|
core::tensor::Tensor,
|
||||||
|
layer::{activation::ReLU, conv::Conv2d},
|
||||||
|
models::sequential::Sequential,
|
||||||
|
};
|
||||||
|
use std::error::Error;
|
||||||
|
|
||||||
fn main() -> Result<(), Box<dyn Error>> {
|
// Simple single-layer model example
|
||||||
let mut framework = Framework::new()?;
|
fn main() -> Result> {
|
||||||
framework.initialize()?;
|
// Create a simple sequential model
|
||||||
|
let mut model = Sequential::new();
|
||||||
|
|
||||||
// Create model
|
// Add a convolutional layer
|
||||||
let model = Sequential::new()
|
model.add(Box::new(Conv2d::new(
|
||||||
.add(Conv2d::new(3, 64, 3, 1, 1))
|
3, // input channels
|
||||||
.add(ReLU::new())
|
4, // output channels
|
||||||
.add(Linear::new(64, 10));
|
(3, 3), // kernel size
|
||||||
|
Some((1, 1)), // stride
|
||||||
|
Some((1, 1)), // padding
|
||||||
|
true, // use bias
|
||||||
|
)));
|
||||||
|
|
||||||
// Configure training
|
// Create input tensor
|
||||||
let optimizer = Adam::new(model.parameters(), 0.001);
|
let input = Tensor::zeros(&[1, 3, 8, 8]); // batch_size=1, channels=3, height=8, width=8
|
||||||
let loss_fn = CrossEntropyLoss::new();
|
|
||||||
let trainer = Trainer::new(model, optimizer, loss_fn);
|
|
||||||
|
|
||||||
// Train model
|
// Forward pass
|
||||||
trainer.fit(train_loader, Some(val_loader), 100)?;
|
let output = model.forward(&input)?;
|
||||||
|
|
||||||
|
// Print output shape
|
||||||
|
println!("Output shape: {:?}", output.shape());
|
||||||
|
assert_eq!(output.shape()[1], 4); // Verify output channels
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user