This library aims to be a complete deep learning framework with extreme flexibility, written in Rust. The goal is to satisfy researchers as well as practitioners, making it easier to experiment, train and deploy your models.
## Features

- Training with full support for metric, logging and checkpointing
## Get Started

🌟 The best way to get started with burn is to look at the examples. It may also be a good idea to check out the main components below to get a quick overview of how to use burn.
### Examples

For now there is only one example, but more to come 💪.

The MNIST example is not just a small script that shows you how to train a basic model; it also shows you how to define your own custom module, build the data pipeline from a raw dataset, and configure a learner with metrics, logging and checkpointing.
The example can be run like so:
```bash
git clone https://github.com/burn-rs/burn.git
cd burn

echo "Using ndarray backend"
cargo run --example mnist --release --features ndarray                # CPU NdArray Backend - f32 - single thread
cargo run --example mnist --release --features ndarray-blas-openblas  # CPU NdArray Backend - f32 - blas with openblas
cargo run --example mnist --release --features ndarray-blas-netlib    # CPU NdArray Backend - f32 - blas with netlib

echo "Using tch backend"
export TORCH_CUDA_VERSION=cu113                                       # Set the cuda version
cargo run --example mnist --release --features tch-gpu                # GPU Tch Backend - f16
cargo run --example mnist --release --features tch-cpu                # CPU Tch Backend - f32
```
### Components

Knowing the main components will be of great help when you start playing with burn.
#### Backend

Almost everything is based on the Backend trait, which allows you to run tensor operations with different implementations without having to change your code. A backend does not necessarily have autodiff capabilities; you can use the ADBackend trait when you require them.
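As a minimal sketch of why the distinction matters, the function below is generic over any Backend, while the training helper adds an ADBackend bound. The exp and backward calls are assumptions based on early burn-tensor APIs and may differ in your version.

```rust
use burn::tensor::Tensor;
use burn::tensor::backend::{ADBackend, Backend};

// Runs on any backend: no gradient tracking required.
fn inference_step<B: Backend>(input: Tensor<B, 2>) -> Tensor<B, 2> {
    input.exp()
}

// Requires an autodiff-capable backend, expressed through the ADBackend bound.
fn training_step<B: ADBackend>(input: Tensor<B, 2>) {
    let output = inference_step(input);
    // backward() is only available when B: ADBackend.
    let _grads = output.backward();
}
```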
#### Tensor

The Tensor struct is at the core of the burn framework. It takes two generic parameters: the Backend and the number of dimensions D.
```rust
use burn::tensor::{Tensor, Shape, Data};
use burn::tensor::backend::{Backend, NdArrayBackend, TchBackend};

fn my_func<B: Backend>() {
    // Create a 2-dimensional tensor of ones; the backend stays generic.
    let _my_tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]));
}

fn main() {
    // The same function runs with different backend implementations.
    my_func::<NdArrayBackend<f32>>();
    my_func::<TchBackend<f32>>();
}
```
#### Module

The Module derive lets you create your own neural network module, similar to PyTorch.
```rust
use burn::nn;
use burn::module::{Param, Module};
use burn::tensor::backend::Backend;

#[derive(Module, Debug)]
struct MyModule<B: Backend> {
    my_param: Param<nn::Linear<B>>,
    repeat: usize,
}
```
Note that only the fields wrapped inside Param are updated during training; the other fields should implement Clone.
#### Forward

The Forward trait can also be implemented by your module.
```rust
use burn::module::Forward;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;

impl<B: Backend> Forward<Tensor<B, 2>, Tensor<B, 2>> for MyModule<B> {
    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let mut x = input;

        // Apply the wrapped linear layer `repeat` times.
        for _ in 0..self.repeat {
            x = self.my_param.forward(x);
        }

        x
    }
}
```
Note that you can implement the Forward trait multiple times with different inputs and outputs.
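For instance, a second implementation on the same module could accept a different input type. The pair-input variant below is a hypothetical sketch, assuming element-wise addition via the + operator is available on Tensor; it reuses the Tensor<B, 2> implementation above.

```rust
use burn::module::Forward;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;

// Hypothetical second implementation for the same `MyModule` as above:
// a pair of tensors is summed element-wise, then passed through the
// existing Tensor<B, 2> forward implementation.
impl<B: Backend> Forward<(Tensor<B, 2>, Tensor<B, 2>), Tensor<B, 2>> for MyModule<B> {
    fn forward(&self, (lhs, rhs): (Tensor<B, 2>, Tensor<B, 2>)) -> Tensor<B, 2> {
        self.forward(lhs + rhs)
    }
}
```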
#### Config

The Config derive lets you define serializable and deserializable configurations or hyper-parameters for your modules or any components.
```rust
use burn::config::Config;

#[derive(Config)]
struct MyConfig {
    #[config(default = 1.0e-6)]
    pub epsilon: f64,
    pub dim: usize,
}
```

The derive also adds useful methods to your config.
```rust
fn my_func() {
    let config = MyConfig::new(100);
    println!("{}", config.epsilon); // 1.0e-6
    println!("{}", config.dim); // 100

    let config = MyConfig::new(100).with_epsilon(1.0e-8);
    println!("{}", config.epsilon); // 1.0e-8
}
```
#### Learner

The Learner is the main struct that lets you train a neural network with support for logging, metrics, checkpointing and more. In order to create a learner, you must use the LearnerBuilder.
```rust
use burn::train::LearnerBuilder;
use burn::train::metric::{AccuracyMetric, LossMetric};

// model and optim are assumed to be created beforehand.
let learner = LearnerBuilder::new("/tmp/artifact_dir")
    .metric_train_plot(AccuracyMetric::new())
    .metric_valid_plot(AccuracyMetric::new())
    .metric_train(LossMetric::new())
    .metric_valid(LossMetric::new())
    .with_file_checkpointer::<f32>(2)
    .build(model, optim);
```
See the MNIST example for a complete, real usage.
## License

Burn is distributed under the terms of both the MIT license and the Apache License (Version 2.0). See LICENSE-APACHE and LICENSE-MIT for details. Opening a pull request is assumed to signal agreement with these licensing terms.