Burn Import
`burn-import` is a crate designed to simplify importing models trained in other machine learning frameworks into the Burn framework. It generates a Rust source file that defines the imported model as a Burn model, and it converts the model's tensor data into a format compatible with Burn.
Currently, `burn-import` supports importing ONNX models with a limited set of operators, as it is still under development.
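To make "converts tensor data into a format compatible with Burn" concrete, consider one kind of conversion an importer can need: a layout change. The sketch below assumes, purely for illustration, that a fully connected weight arrives as `[out_features, in_features]` (which is how ONNX `Gemm` stores its second input when `transB = 1`) and must be stored as `[in_features, out_features]`. It is not burn-import's actual code, and `transpose_row_major` is a hypothetical helper.

```rust
/// Transposes a row-major `rows x cols` matrix stored in a flat slice,
/// returning a row-major `cols x rows` matrix.
///
/// Hypothetical helper: it only illustrates the kind of layout change an
/// importer may apply, e.g. `[out_features, in_features]` -> `[in_features, out_features]`.
fn transpose_row_major(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(data.len(), rows * cols, "data length must equal rows * cols");
    let mut out = vec![0.0; data.len()];
    for r in 0..rows {
        for c in 0..cols {
            // Element (r, c) of the source becomes element (c, r) of the result.
            out[c * rows + r] = data[r * cols + c];
        }
    }
    out
}
```

Working on a flat slice mirrors how tensor data is typically serialized (one contiguous buffer plus a shape), so the conversion only reorders elements and never changes their values.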
ONNX Operators
The list below is taken from the official ONNX operator list.
- [ ] Abs
- [ ] Acos
- [ ] Acosh
- [x] Add
- [ ] And
- [ ] ArgMax
- [ ] ArgMin
- [ ] Asin
- [ ] Asinh
- [ ] Atan
- [ ] Atanh
- [ ] AveragePool
- [x] BatchNormalization
- [ ] Bernoulli
- [ ] BitShift
- [ ] BitwiseAnd
- [ ] BitwiseNot
- [ ] BitwiseOr
- [ ] BitwiseXor
- [ ] BlackmanWindow
- [ ] Cast
- [ ] CastLike
- [ ] Ceil
- [ ] Celu
- [ ] CenterCropPad
- [ ] Clip
- [ ] Col2Im
- [ ] Compress
- [ ] Concat
- [ ] ConcatFromSequence
- [ ] Constant
- [ ] ConstantOfShape
- [ ] Conv
- [ ] Conv1d
- [x] Conv2d
- [ ] ConvInteger
- [ ] ConvTranspose
- [ ] Cos
- [ ] Cosh
- [ ] CumSum
- [ ] DepthToSpace
- [ ] DequantizeLinear
- [ ] Det
- [ ] DFT
- [ ] Div
- [ ] Dropout
- [ ] DynamicQuantizeLinear
- [ ] Einsum
- [ ] Elu
- [ ] Equal
- [ ] Erf
- [ ] Exp
- [ ] Expand
- [ ] EyeLike
- [x] Flatten
- [ ] Floor
- [ ] Gather
- [ ] GatherElements
- [ ] GatherND
- [ ] Gelu
- [x] Gemm (Linear Layer)
- [ ] GlobalAveragePool
- [ ] GlobalLpPool
- [ ] GlobalMaxPool
- [ ] Greater
- [ ] GreaterOrEqual
- [ ] GridSample
- [ ] GroupNormalization
- [ ] GRU
- [ ] HammingWindow
- [ ] HannWindow
- [ ] Hardmax
- [ ] HardSigmoid
- [ ] HardSwish
- [ ] Identity
- [ ] If
- [ ] Im
- [ ] InstanceNormalization
- [ ] IsInf
- [ ] IsNaN
- [ ] LayerNormalization
- [ ] LeakyRelu
- [ ] Less
- [ ] LessOrEqual
- [ ] Linear
- [ ] Log
- [x] LogSoftmax
- [ ] Loop
- [ ] LpNormalization
- [ ] LpPool
- [ ] LRN
- [ ] LSTM
- [ ] MatMul
- [ ] MatMulInteger
- [ ] Max
- [ ] MaxPool
- [ ] MaxPool1d
- [x] MaxPool2d
- [ ] MaxRoiPool
- [ ] MaxUnpool
- [ ] Mean
- [ ] MeanVarianceNormalization
- [ ] MelWeightMatrix
- [ ] Min
- [ ] Mish
- [ ] Mod
- [ ] Mul
- [ ] Multinomial
- [ ] Neg
- [ ] NegativeLogLikelihoodLoss
- [ ] NonMaxSuppression
- [ ] NonZero
- [ ] Not
- [ ] OneHot
- [ ] Optional
- [ ] OptionalGetElement
- [ ] OptionalHasElement
- [ ] Or
- [ ] Pad
- [ ] Pow
- [ ] PRelu
- [ ] QLinearConv
- [ ] QLinearMatMul
- [ ] QuantizeLinear
- [ ] RandomNormal
- [ ] RandomNormalLike
- [ ] RandomUniform
- [ ] RandomUniformLike
- [ ] Range
- [ ] Reciprocal
- [ ] ReduceL1
- [ ] ReduceL2
- [ ] ReduceLogSum
- [ ] ReduceLogSumExp
- [ ] ReduceMax
- [ ] ReduceMean
- [ ] ReduceMin
- [ ] ReduceProd
- [ ] ReduceSum
- [ ] ReduceSumSquare
- [x] Relu
- [ ] Reshape
- [ ] Resize
- [ ] ReverseSequence
- [ ] RNN
- [ ] RoiAlign
- [ ] Round
- [ ] Scan
- [ ] Scatter
- [ ] ScatterElements
- [ ] ScatterND
- [ ] Selu
- [ ] SequenceAt
- [ ] SequenceConstruct
- [ ] SequenceEmpty
- [ ] SequenceErase
- [ ] SequenceInsert
- [ ] SequenceLength
- [ ] SequenceMap
- [ ] Shape
- [ ] Shrink
- [x] Sigmoid
- [ ] Sign
- [ ] Sin
- [ ] Sinh
- [ ] Size
- [ ] Slice
- [ ] Softmax
- [ ] SoftmaxCrossEntropyLoss
- [ ] Softplus
- [ ] Softsign
- [ ] SpaceToDepth
- [ ] Split
- [ ] SplitToSequence
- [ ] Sqrt
- [ ] Squeeze
- [ ] STFT
- [ ] StringNormalizer
- [ ] Sub
- [ ] Sum
- [ ] Tan
- [ ] Tanh
- [ ] TfIdfVectorizer
- [ ] ThresholdedRelu
- [ ] Tile
- [ ] TopK
- [ ] Transpose
- [ ] Trilu
- [ ] Unique
- [ ] Unsqueeze
- [ ] Upsample
- [ ] Where
- [ ] Xor
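Before converting a model, it can help to compare the op types it uses (for example, as shown by Netron) against the ticked entries above. The helper below is a hypothetical sketch, not part of burn-import. Note that the list ticks only the 2D variants `Conv2d` and `MaxPool2d`, whereas ONNX files record the generic `Conv` and `MaxPool` op types, so matching by name is only a rough first check.

```rust
use std::collections::HashSet;

/// Returns the distinct op types in `model_ops` that are not in `supported`,
/// in the order they first appear. Hypothetical helper for illustration only.
fn unsupported_ops<'a>(model_ops: &[&'a str], supported: &[&str]) -> Vec<&'a str> {
    let supported: HashSet<&str> = supported.iter().copied().collect();
    let mut seen = HashSet::new();
    let mut missing = Vec::new();
    for &op in model_ops {
        // `seen.insert` returns false for duplicates, so each op is reported once.
        if !supported.contains(op) && seen.insert(op) {
            missing.push(op);
        }
    }
    missing
}
```

For example, a model using `Conv`, `Relu`, `Softmax` and `Gemm`, checked against `["Relu", "Gemm", "Add"]`, would report `Conv` and `Softmax` as the operators to look at first.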
Usage
Importing ONNX models
To import ONNX models, follow these steps:
Add the following code to your `build.rs` file:
```rust
use burn_import::onnx::ModelGen;

fn main() {
    // Generate the model code and state file from the ONNX file.
    ModelGen::new()
        .input("src/model/mnist.onnx") // Path to the ONNX model
        .out_dir("model/") // Directory for the generated Rust source file (under target/)
        .run_from_script();
}
```
Add the following code to the `mod.rs` file under `src/model`:
```rust
pub mod mnist {
    include!(concat!(env!("OUT_DIR"), "/model/mnist.rs"));
}
```
Use the imported model in your code as shown below:
```rust
mod model;

use burn::tensor;
use burn_ndarray::NdArrayBackend;
use model::mnist::Model;

fn main() {
    // Create a new model
    let model: Model<NdArrayBackend<f32>> = Model::new();

    // Create a new input tensor (all zeros for demonstration purposes)
    let input = tensor::Tensor::<NdArrayBackend<f32>, 4>::zeros([1, 1, 28, 28]);

    // Run the model
    let output = model.forward(input);

    // Print the output
    println!("{:?}", output);
}
```
A working example can be found in the `examples/onnx-inference` directory.
Adding new operators
To add support for new operators to `burn-import`, follow these steps:
- Optimize the ONNX model using onnxoptimizer. This removes unnecessary operators and constants, making the model easier to understand.
- Use the Netron app to visualize the ONNX model.
- Generate artifact files for the ONNX model (`my-model.onnx`) and its components by running `cargo r -- ./my-model.onnx ./`.
- Implement the missing operators when you encounter an error stating that an operator is not supported. Ideally, the `my-model.graph.txt` file is generated before the error occurs, so it still provides information about the ONNX model.
- The newly generated `my-model.graph.txt` file contains the intermediate representation (IR) of the model, while the `my-model.rs` file contains the actual Burn model in Rust code. The `my-model.json` file contains the model data.
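The artifact names above follow the stem of the input file. The sketch below derives them with `std::path`, assuming (from the example command) that the second argument is the output directory; `Artifacts` and `artifact_paths` are hypothetical names, not burn-import's API.

```rust
use std::path::{Path, PathBuf};

/// Paths of the artifacts described above for a given ONNX file.
/// Hypothetical struct for illustration.
struct Artifacts {
    graph: PathBuf,  // <stem>.graph.txt: IR dump of the model
    source: PathBuf, // <stem>.rs: generated Burn model
    data: PathBuf,   // <stem>.json: model data
}

/// Builds the artifact paths for `onnx` inside `out_dir`.
/// Returns `None` if the path has no usable UTF-8 file stem.
fn artifact_paths(onnx: &Path, out_dir: &Path) -> Option<Artifacts> {
    // `file_stem` strips the final extension: "my-model.onnx" -> "my-model".
    let stem = onnx.file_stem()?.to_str()?;
    Some(Artifacts {
        graph: out_dir.join(format!("{stem}.graph.txt")),
        source: out_dir.join(format!("{stem}.rs")),
        data: out_dir.join(format!("{stem}.json")),
    })
}
```

With `./my-model.onnx` and `./`, this yields `./my-model.graph.txt`, `./my-model.rs` and `./my-model.json`, which is a quick way to know which files to open after a failed conversion.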
The `src/onnx` directory contains the following ONNX modules:

- `coalesce.rs`: Coalesces multiple ONNX operators into a single Burn operator. This is useful for operators that are not supported by Burn but can be represented by a combination of supported operators.
- `op_configuration.rs`: Contains helper functions for configuring Burn operators from operator nodes.
- `shape_inference.rs`: Contains helper functions for inferring the shapes of the input and output tensors of operators.
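The coalescing idea can be sketched on a toy graph: a `MatMul` whose result feeds straight into an `Add` computes `y = xW + b`, which is exactly a linear layer, so the pair can be emitted as one `Linear` node. Everything here (`Node`, `coalesce_linear`, the field names) is hypothetical and much simpler than burn-import's real IR; it is not the code in `coalesce.rs`.

```rust
/// A deliberately tiny, hypothetical node representation (not burn-import's IR).
#[derive(Debug, Clone, PartialEq)]
enum Node {
    MatMul { input: String, weight: String, output: String },
    Add { lhs: String, rhs: String, output: String },
    Linear { input: String, weight: String, bias: String, output: String },
    Other { op: String, output: String },
}

/// Fuses `MatMul(x, W) -> t` immediately followed by `Add(t, b) -> y`
/// into a single `Linear(x, W, b) -> y`; all other nodes pass through unchanged.
/// A real pass would also check that `t` has no other consumers, accept
/// `Add(b, t)`, and verify shapes; this sketch skips all of that.
fn coalesce_linear(nodes: Vec<Node>) -> Vec<Node> {
    let mut out = Vec::with_capacity(nodes.len());
    let mut iter = nodes.into_iter().peekable();
    while let Some(node) = iter.next() {
        if let Node::MatMul { input, weight, output } = &node {
            // Only fuse when the very next node adds a bias to the MatMul result.
            if let Some(Node::Add { lhs, rhs, output: y }) = iter.peek() {
                if lhs == output {
                    let fused = Node::Linear {
                        input: input.clone(),
                        weight: weight.clone(),
                        bias: rhs.clone(),
                        output: y.clone(),
                    };
                    out.push(fused);
                    iter.next(); // consume the Add that was fused
                    continue;
                }
            }
        }
        out.push(node);
    }
    out
}
```

Only looking at the immediately following node keeps the sketch linear-time and easy to follow; matching patterns across a general graph requires tracking each tensor's consumers.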
Add unit tests for the new operator in the `burn-import/tests/onnx_tests.rs` file, and add the ONNX file and its expected output to the `tests/data` directory. Keep the ONNX file small: large files increase the repository size and make it harder to maintain and clone. Refer to the existing unit tests for examples.
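When a test compares the imported model's output with the stored expected output, small floating-point differences between frameworks are normal, so an element-wise tolerance check is a common choice. A minimal sketch; the helper name and the tolerance value are assumptions, and burn-import's own tests may compare outputs differently.

```rust
/// Asserts that two outputs match element-wise within an absolute tolerance.
/// Hypothetical helper for illustration.
fn assert_approx_eq(actual: &[f32], expected: &[f32], tol: f32) {
    assert_eq!(actual.len(), expected.len(), "length mismatch");
    for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
        // Report the first offending element to make failures easy to debug.
        assert!(
            (a - e).abs() <= tol,
            "element {i}: got {a}, expected {e} (tolerance {tol})"
        );
    }
}
```

An absolute tolerance is simplest when outputs are of similar magnitude (for example, log-probabilities); for values spanning several orders of magnitude, a relative tolerance is usually the safer choice.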
Resources
- PyTorch to ONNX
- ONNX to PyTorch
- ONNX Intro
- ONNX Operators
- ONNX Protos
- ONNX Optimizer
- Netron