`burn-import` is a crate designed to simplify importing models trained in other machine learning frameworks into the Burn framework. This tool generates a Rust source file that expresses the imported model as a Burn model, and it converts the tensor data into a format compatible with Burn.
Currently, `burn-import` supports importing ONNX models with a limited set of operators, as it is still under development.
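Because only a limited set of operators is supported, it can help to check which operator types a model uses before importing it. The following is a minimal, standard-library-only sketch of that check. `unsupported_ops` is a hypothetical helper, not part of `burn-import`, and both operator lists in `main` are made-up examples rather than the crate's actual coverage.

```rust
use std::collections::BTreeSet;

/// Hypothetical helper (not part of burn-import): returns the operator
/// types used by a model that are missing from `supported`, sorted and
/// de-duplicated.
fn unsupported_ops<'a>(model_ops: &[&'a str], supported: &[&str]) -> Vec<&'a str> {
    let supported: BTreeSet<&str> = supported.iter().copied().collect();
    let missing: BTreeSet<&'a str> = model_ops
        .iter()
        .copied()
        .filter(|op| !supported.contains(op))
        .collect();
    missing.into_iter().collect()
}

fn main() {
    // Made-up example data: NOT burn-import's real operator coverage.
    let supported = ["Conv", "Relu", "MaxPool", "Flatten", "Gemm"];
    let model_ops = ["Conv", "Relu", "MaxPool", "Flatten", "Gemm", "Erf", "LogSoftmax", "Erf"];
    let missing = unsupported_ops(&model_ops, &supported);
    println!("unsupported operators: {:?}", missing);
}
```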
To import ONNX models, follow these steps.

First, add the following code to your `build.rs` file:
```rust
use burn_import::onnx::ModelGen;

fn main() {
    // Generate the model code and state file from the ONNX file.
    ModelGen::new()
        .input("src/model/mnist.onnx") // Path to the ONNX model
        .out_dir("model/") // Directory for the generated Rust source file (under target/)
        .run_from_script();
}
```
Next, add the following code to the `mod.rs` file under `src/model`:
```rust
pub mod mnist {
    include!(concat!(env!("OUT_DIR"), "/model/mnist.rs"));
}
```
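The path passed to `include!` has to match where the build script writes the generated file. Judging from the example above, `out_dir("model/")` is resolved relative to Cargo's `OUT_DIR`, and the file is named after the ONNX file's stem (`mnist.onnx` becomes `mnist.rs`). The sketch below reproduces that mapping using only `std::path`. `generated_source_path` is a hypothetical helper for illustration, not part of `burn-import`, and the `OUT_DIR` value in `main` is an example.

```rust
use std::path::{Path, PathBuf};

/// Hypothetical helper: computes `<OUT_DIR>/<model_subdir>/<onnx stem>.rs`,
/// the same path that `concat!(env!("OUT_DIR"), "/model/mnist.rs")` spells out.
/// Returns `None` if the ONNX path has no file name.
fn generated_source_path(out_dir: &Path, model_subdir: &str, onnx_input: &Path) -> Option<PathBuf> {
    // "src/model/mnist.onnx" -> "mnist" -> "mnist.rs"
    let mut file_name = onnx_input.file_stem()?.to_os_string();
    file_name.push(".rs");
    Some(out_dir.join(model_subdir).join(file_name))
}

fn main() {
    // Example value only: Cargo chooses the real OUT_DIR under target/.
    let out_dir = Path::new("/project/target/debug/build/my-app-0123/out");
    if let Some(path) = generated_source_path(out_dir, "model/", Path::new("src/model/mnist.onnx")) {
        println!("{}", path.display());
    }
}
```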
Finally, use the imported model in your code as shown below:
```rust
use burn::tensor;
use burn_ndarray::NdArrayBackend;
use onnx_inference::model::mnist::{Model, INPUT1_SHAPE};

fn main() {
    // Create a new model
    let model: Model<NdArrayBackend<f32>> = Model::new();

    // Create a new input tensor (all zeros for demonstration purposes)
    let input = tensor::Tensor::<NdArrayBackend<f32>, 4>::zeros(INPUT1_SHAPE);

    // Run the model
    let output = model.forward(input);

    // Print the output
    println!("{:?}", output);
}
```
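The example above only prints the raw output. For an MNIST classifier, a common next step is to pick the class with the highest score. The sketch below assumes the output holds one score per digit (10 values) and that those scores have already been copied out of the Burn tensor into a plain `f32` slice; how to do that copy depends on the Burn version, so it is left out. `predicted_digit` is a hypothetical helper, and the scores in `main` are made up.

```rust
/// Hypothetical helper: index of the highest score, i.e. the predicted class.
/// Returns `None` for an empty slice. `total_cmp` imposes a total order on
/// floats, so NaN values do not cause a panic.
fn predicted_digit(scores: &[f32]) -> Option<usize> {
    scores
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .map(|(index, _)| index)
}

fn main() {
    // Made-up scores standing in for the model output.
    let scores = [0.1, -0.3, 2.5, 0.0, 0.4, -1.2, 0.2, 1.1, 0.0, -0.4];
    match predicted_digit(&scores) {
        Some(digit) => println!("predicted digit: {digit}"),
        None => println!("empty output"),
    }
}
```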
A working example can be found in the `examples/onnx-inference` directory.
To add support for new operators to `burn-import`, follow these steps:
Generate files describing the ONNX model (`my-model.onnx`) and its components:
```
cargo r -- ./my-model.onnx ./
```
If the model uses an operator that is not yet supported, the command fails with an error; the `my-model.graph.txt` file is generated before that error occurs, providing information about the ONNX model.

The generated `my-model.graph.txt` file contains the intermediate representation (IR) of the model, while the `my-model.rs` file contains the actual Burn model in Rust code. The `my-model.json` file contains the model data.

The `src/onnx` directory includes the following ONNX modules:
- `coalesce.rs`: Coalesces multiple ONNX operators into a single Burn operator. This is useful for operators that are not supported by Burn but can be represented by a combination of supported operators.
- `op_configuration.rs`: Contains helper functions for configuring Burn operators from operator nodes.
- `shape_inference.rs`: Contains helper functions for inferring the shapes of operators' input and output tensors.

Add unit tests for the new operator in the `burn-import/tests/onnx_tests.rs` file, and add the ONNX file and its expected output to the `tests/data` directory. Keep the ONNX file small, since large files increase the repository size and make it harder to maintain and clone. Refer to the existing unit tests for examples.
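To make the roles of `coalesce.rs` and `shape_inference.rs` more concrete, here is a toy sketch over a made-up, flat list of operators. It fuses each `MatMul` immediately followed by an `Add` into one `Linear` node (the kind of pattern a coalescing pass targets), then propagates a tensor shape through the result. The `Op` enum, `coalesce`, and `infer_shape` are hypothetical and much simpler than the IR `burn-import` actually uses.

```rust
/// Toy operator list (hypothetical, not burn-import's IR).
#[derive(Debug, Clone, PartialEq)]
enum Op {
    /// Matrix multiply by a weight of shape [d_in, d_out].
    MatMul { weight_shape: [usize; 2] },
    /// Element-wise add (assumed here to be a bias add).
    Add,
    /// Fused MatMul + Add.
    Linear { d_in: usize, d_out: usize },
    /// Element-wise activation.
    Relu,
}

/// Coalescing pass: replace every `MatMul` immediately followed by `Add`
/// with a single `Linear` node; all other ops are copied unchanged.
fn coalesce(ops: &[Op]) -> Vec<Op> {
    let mut fused = Vec::with_capacity(ops.len());
    let mut i = 0;
    while i < ops.len() {
        match (&ops[i], ops.get(i + 1)) {
            (Op::MatMul { weight_shape: [d_in, d_out] }, Some(Op::Add)) => {
                fused.push(Op::Linear { d_in: *d_in, d_out: *d_out });
                i += 2; // consumed both nodes
            }
            (op, _) => {
                fused.push(op.clone());
                i += 1;
            }
        }
    }
    fused
}

/// Shape inference: propagate `input` through `ops`, failing if a
/// MatMul/Linear receives a last dimension that does not match `d_in`.
fn infer_shape(input: &[usize], ops: &[Op]) -> Result<Vec<usize>, String> {
    let mut shape = input.to_vec();
    for op in ops {
        match op {
            Op::MatMul { weight_shape: [d_in, d_out] } | Op::Linear { d_in, d_out } => {
                let last = shape.last_mut().ok_or("rank-0 input")?;
                if *last != *d_in {
                    return Err(format!("expected last dimension {d_in}, got {last}"));
                }
                *last = *d_out;
            }
            // Element-wise ops keep the shape (broadcasting is ignored here).
            Op::Add | Op::Relu => {}
        }
    }
    Ok(shape)
}

fn main() {
    // A made-up two-layer MLP on flattened 28x28 images.
    let graph = vec![
        Op::MatMul { weight_shape: [784, 128] },
        Op::Add,
        Op::Relu,
        Op::MatMul { weight_shape: [128, 10] },
        Op::Add,
    ];
    let fused = coalesce(&graph);
    println!("{:?}", fused);
    match infer_shape(&[1, 784], &fused) {
        Ok(shape) => println!("output shape: {:?}", shape),
        Err(e) => println!("shape error: {e}"),
    }
}
```

A real pass would also need to check that the `Add` consumes the `MatMul` output and that its other operand is a bias vector; the toy list skips those checks.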