Rust bindings for the C++ api of PyTorch. The goal of the `tch` crate is to
provide some thin wrappers around the C++ PyTorch api (a.k.a. libtorch). It
aims at staying as close as possible to the original C++ api. More idiomatic
Rust bindings could then be developed on top of this. The documentation can be
found on docs.rs.
The code generation part for the C api on top of libtorch comes from ocaml-torch.
This crate requires the C++ PyTorch library (libtorch) in version v2.0.0 to be available on your system. You can either:
- Use the system-wide libtorch installation.
- Install libtorch manually and let the build script know about it via the `LIBTORCH` environment variable.
- Use a Python PyTorch install by setting `LIBTORCH_USE_PYTORCH=1`.
- When a system-wide libtorch can't be found and `LIBTORCH` is not set, the build script can download a pre-built binary version of libtorch by using the `download-libtorch` feature. By default a CPU version is used. The `TORCH_CUDA_VERSION` environment variable can be set to `cu117` in order to get a pre-built binary using CUDA 11.7.

On Linux platforms, the build script will look for a system-wide libtorch
library in `/usr/lib/libtorch.so`.

If the `LIBTORCH_USE_PYTORCH` environment variable is set, the active Python
interpreter is called to retrieve information about the torch Python package.
This version is then linked against.
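The choice between the options above can be pictured as a small decision function. The sketch below is only an illustration: `LibtorchSource` and `select_source` are hypothetical names, and the order in which the options are tried is an assumption, not necessarily what the crate's build script does.

```rust
use std::path::PathBuf;

/// Where a libtorch installation could come from, following the options above.
#[derive(Debug, PartialEq)]
enum LibtorchSource {
    /// `LIBTORCH` points at a manually unpacked libtorch directory.
    Manual(PathBuf),
    /// `LIBTORCH_USE_PYTORCH` is set: ask the active Python interpreter.
    PythonInstall,
    /// A system-wide library such as `/usr/lib/libtorch.so` on Linux.
    SystemWide(PathBuf),
    /// Pre-built binary fetched through the `download-libtorch` feature.
    /// `cuda: None` stands for the default CPU build.
    Download { cuda: Option<String> },
    /// None of the options applies.
    NotFound,
}

/// Hypothetical helper; the precedence used here is an assumption.
fn select_source(
    libtorch: Option<&str>,
    use_pytorch: bool,
    system_lib_exists: bool,
    download_enabled: bool,
    torch_cuda_version: Option<&str>,
) -> LibtorchSource {
    if let Some(dir) = libtorch {
        return LibtorchSource::Manual(PathBuf::from(dir));
    }
    if use_pytorch {
        return LibtorchSource::PythonInstall;
    }
    if system_lib_exists {
        return LibtorchSource::SystemWide(PathBuf::from("/usr/lib/libtorch.so"));
    }
    if download_enabled {
        return LibtorchSource::Download {
            cuda: torch_cuda_version.map(String::from),
        };
    }
    LibtorchSource::NotFound
}
```

For instance, with no environment variable set, no system-wide library and the download feature enabled, `select_source(None, false, false, true, Some("cu117"))` yields the `Download` variant carrying `cu117`.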
- Get `libtorch` from the PyTorch website download section and extract the content of the zip file.
- For Linux and macOS users, add the following to your `.bashrc` or equivalent, where `/path/to/libtorch` is the path to the directory that was created when unzipping the file.

```bash
export LIBTORCH=/path/to/libtorch
export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH
```

The header files location can also be specified separately from the shared library via
the following:
```bash
# LIBTORCH_INCLUDE must contain the `include` directory.
export LIBTORCH_INCLUDE=/path/to/libtorch/
# LIBTORCH_LIB must contain the `lib` directory.
export LIBTORCH_LIB=/path/to/libtorch/
```

- For Windows users, assuming that `X:\path\to\libtorch` is the unzipped libtorch directory.
    - Navigate to Control Panel -> View advanced system settings -> Environment variables.
    - Create the `LIBTORCH` variable and set it to `X:\path\to\libtorch`.
    - Append `X:\path\to\libtorch\lib` to the `Path` variable.

If you prefer to temporarily set environment variables, in PowerShell you can run

```powershell
$Env:LIBTORCH = "X:\path\to\libtorch"
$Env:Path += ";X:\path\to\libtorch\lib"
```

- You should now be able to run some examples, e.g. `cargo run --example basics`.

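As a rough illustration of how `LIBTORCH`, `LIBTORCH_INCLUDE` and `LIBTORCH_LIB` relate to each other, the sketch below derives the header and library directories from the three values. `resolve_dirs` is a hypothetical helper, and letting the two specific variables take priority over `LIBTORCH` is an assumption made for the example.

```rust
use std::path::PathBuf;

/// Hypothetical helper returning the `(include, lib)` directories.
/// Assumption: `LIBTORCH_INCLUDE` and `LIBTORCH_LIB` override `LIBTORCH`
/// when they are set, and each root contains the matching subdirectory.
fn resolve_dirs(
    libtorch: Option<&str>,
    libtorch_include: Option<&str>,
    libtorch_lib: Option<&str>,
) -> Option<(PathBuf, PathBuf)> {
    // Fall back to LIBTORCH when the more specific variable is missing.
    let include_root = libtorch_include.or(libtorch)?;
    let lib_root = libtorch_lib.or(libtorch)?;
    Some((
        PathBuf::from(include_root).join("include"),
        PathBuf::from(lib_root).join("lib"),
    ))
}
```

With only `LIBTORCH=/path/to/libtorch` set, this gives `/path/to/libtorch/include` and `/path/to/libtorch/lib`. When neither a specific variable nor `LIBTORCH` is available for one of the two directories, the function returns `None`.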
As per the pytorch docs the Windows debug and release builds are not ABI-compatible. This could lead to some segfaults if the incorrect version of libtorch is used.
It is recommended to use the MSVC Rust toolchain (e.g. by installing `stable-x86_64-pc-windows-msvc`
via rustup) rather than a MinGW-based one, as PyTorch has compatibility issues with MinGW.
When setting the environment variable `LIBTORCH_STATIC=1`, `libtorch` is statically
linked rather than using the dynamic libraries. The pre-compiled artifacts don't
seem to include `libtorch.a` by default so this would have to be compiled
manually, e.g. via the following:

```bash
git clone -b v2.0.0 --recurse-submodule https://github.com/pytorch/pytorch.git pytorch-static --depth 1
cd pytorch-static
USE_CUDA=OFF BUILD_SHARED_LIBS=OFF python setup.py build
```

This crate provides a tensor type which wraps PyTorch tensors. Here is a minimal example of how to perform some tensor operations.
```rust
use tch::Tensor;

fn main() {
    let t = Tensor::of_slice(&[3, 1, 4, 1, 5]);
    let t = t * 2;
    t.print();
}
```

PyTorch provides automatic differentiation for most tensor operations
it supports. This is commonly used to train models using gradient
descent. The optimization is performed over variables which are created
via a `nn::VarStore` by defining their shapes and initializations.

In the example below `my_module` uses two variables `x1` and `x2`
whose initial values are 0. The forward pass applied to tensor `xs`
returns `xs * x1 + exp(xs) * x2`.

Once the model has been generated, a `nn::Sgd` optimizer is created.
Then on each step of the training loop:

- The forward pass is applied to a mini-batch of data.
- A loss is computed as the squared error between the model output and the mini-batch ground truth.
- Finally an optimization step is performed: gradients are computed and variables from the `VarStore` are modified accordingly.

```rust
use tch::nn::{Module, OptimizerConfig};
use tch::{kind, nn, Device, Tensor};

fn my_module(p: nn::Path, dim: i64) -> impl nn::Module {
    let x1 = p.zeros("x1", &[dim]);
    let x2 = p.zeros("x2", &[dim]);
    nn::func(move |xs| xs * &x1 + xs.exp() * &x2)
}

fn gradient_descent() {
    let vs = nn::VarStore::new(Device::Cpu);
    let my_module = my_module(vs.root(), 7);
    let mut opt = nn::Sgd::default().build(&vs, 1e-2).unwrap();
    for _idx in 1..50 {
        // Dummy mini-batches made of zeros.
        let xs = Tensor::zeros(&[7], kind::FLOAT_CPU);
        let ys = Tensor::zeros(&[7], kind::FLOAT_CPU);
        let loss = (my_module.forward(&xs) - ys).pow_tensor_scalar(2).sum(kind::Kind::Float);
        opt.backward_step(&loss);
    }
}
```

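To make the optimization step more concrete, here is a plain-Rust sketch of the same model with the gradients of the summed squared error written out by hand. It does not use `tch`: `forward` and `sgd_step` are hypothetical helpers, and the hand-derived update stands in for what automatic differentiation and the optimizer would do.

```rust
/// Forward pass from the text: out_i = xs_i * x1_i + exp(xs_i) * x2_i.
fn forward(xs: &[f64], x1: &[f64], x2: &[f64]) -> Vec<f64> {
    (0..xs.len())
        .map(|i| xs[i] * x1[i] + xs[i].exp() * x2[i])
        .collect()
}

/// One plain SGD step on the summed squared error.
/// Returns the loss measured before the variables are updated.
fn sgd_step(xs: &[f64], ys: &[f64], x1: &mut [f64], x2: &mut [f64], lr: f64) -> f64 {
    let out = forward(xs, x1, x2);
    let mut loss = 0.0;
    for i in 0..xs.len() {
        let residual = out[i] - ys[i];
        loss += residual * residual;
        // d(loss)/d(x1_i) = 2 * residual * xs_i
        x1[i] -= lr * 2.0 * residual * xs[i];
        // d(loss)/d(x2_i) = 2 * residual * exp(xs_i)
        x2[i] -= lr * 2.0 * residual * xs[i].exp();
    }
    loss
}
```

Calling `sgd_step` repeatedly on a fixed mini-batch with non-zero targets and a small enough learning rate should make the returned loss shrink, which is the behaviour the training loop relies on.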
The `nn` api can be used to create neural network architectures, e.g. the following code defines
a simple model with one hidden layer and trains it on the MNIST dataset using the Adam optimizer.
```rust
use anyhow::Result;
use tch::{nn, nn::Module, nn::OptimizerConfig, Device};

const IMAGE_DIM: i64 = 784;
const HIDDEN_NODES: i64 = 128;
const LABELS: i64 = 10;

fn net(vs: &nn::Path) -> impl Module {
    nn::seq()
        .add(nn::linear(
            vs / "layer1",
            IMAGE_DIM,
            HIDDEN_NODES,
            Default::default(),
        ))
        .add_fn(|xs| xs.relu())
        .add(nn::linear(vs, HIDDEN_NODES, LABELS, Default::default()))
}

pub fn run() -> Result<()> {
    let m = tch::vision::mnist::load_dir("data")?;
    let vs = nn::VarStore::new(Device::Cpu);
    let net = net(&vs.root());
    let mut opt = nn::Adam::default().build(&vs, 1e-3)?;
    for epoch in 1..200 {
        let loss = net
            .forward(&m.train_images)
            .cross_entropy_for_logits(&m.train_labels);
        opt.backward_step(&loss);
        let test_accuracy = net
            .forward(&m.test_images)
            .accuracy_for_logits(&m.test_labels);
        println!(
            "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
            epoch,
            f64::from(&loss),
            100. * f64::from(&test_accuracy),
        );
    }
    Ok(())
}
```

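The two quantities printed in the loop can be spelled out on plain slices. The following sketch shows the usual definitions of mean cross-entropy on logits and of classification accuracy. It is a standalone illustration with hypothetical helper names, not the crate's implementation.

```rust
/// Log-sum-exp of one row of logits, shifted by the max for numerical stability.
fn log_sum_exp(row: &[f64]) -> f64 {
    let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    max + row.iter().map(|&v| (v - max).exp()).sum::<f64>().ln()
}

/// Mean over the batch of -log(softmax(row)[label]).
fn cross_entropy(logits: &[Vec<f64>], labels: &[usize]) -> f64 {
    let total: f64 = logits
        .iter()
        .zip(labels)
        .map(|(row, &label)| log_sum_exp(row) - row[label])
        .sum();
    total / logits.len() as f64
}

/// Fraction of rows whose largest logit sits at the labelled index.
fn accuracy(logits: &[Vec<f64>], labels: &[usize]) -> f64 {
    let mut correct = 0usize;
    for (row, &label) in logits.iter().zip(labels) {
        // Index of the first maximal logit in the row.
        let mut best = 0;
        for (i, &v) in row.iter().enumerate() {
            if v > row[best] {
                best = i;
            }
        }
        if best == label {
            correct += 1;
        }
    }
    correct as f64 / logits.len() as f64
}
```

For a single row of two equal logits the cross-entropy is `ln 2`, the loss of a classifier that cannot tell two classes apart.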
More details on the training loop can be found in the detailed tutorial.
The pretrained-models example illustrates how to use a pre-trained computer vision model on an image. The weights, which have been extracted from the PyTorch implementation, can be downloaded here: resnet18.ot and resnet34.ot.
The example can then be run via the following command:
```bash
cargo run --example pretrained-models -- resnet18.ot tiger.jpg
```
This should print the top 5 imagenet categories for the image. The code for this example is pretty simple.
```rust
// First the image is loaded and resized to 224x224.
let image = imagenet::load_image_and_resize(image_file)?;

// A variable store is created to hold the model parameters.
let vs = tch::nn::VarStore::new(tch::Device::Cpu);

// Then the model is built on this variable store, and the weights are loaded.
let resnet18 = tch::vision::resnet::resnet18(vs.root(), imagenet::CLASS_COUNT);
vs.load(weight_file)?;

// Apply the forward pass of the model to get the logits and convert them
// to probabilities via a softmax.
let output = resnet18
    .forward_t(&image.unsqueeze(0), /*train=*/ false)
    .softmax(-1);

// Finally print the top 5 categories and their associated probabilities.
for (probability, class) in imagenet::top(&output, 5).iter() {
    println!("{:50} {:5.2}%", class, 100.0 * probability)
}
```
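The last two steps, turning logits into probabilities and keeping the most likely classes, can be sketched on a plain slice. `softmax` and `top_k` below are hypothetical standalone helpers and return class indices rather than class names. They illustrate the computation, not the library functions used above.

```rust
/// Softmax over a slice of logits, shifted by the max for numerical stability.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|&v| (v - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// The `k` most probable (probability, class index) pairs, most probable first.
fn top_k(probs: &[f64], k: usize) -> Vec<(f64, usize)> {
    let mut pairs: Vec<(f64, usize)> = probs.iter().cloned().zip(0..).collect();
    // Sort by descending probability.
    pairs.sort_by(|a, b| b.0.total_cmp(&a.0));
    pairs.truncate(k);
    pairs
}
```

Applied to the logits `[1.0, 3.0, 2.0]`, the probabilities sum to one and the two best class indices are 1 and 2, in that order.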
Further examples include:

* A simplified version of char-rnn illustrating character level language modeling using Recurrent Neural Networks.
* Neural style transfer uses a pre-trained VGG-16 model to compose an image in the style of another image (pre-trained weights: vgg16.ot).
* Some ResNet examples on CIFAR-10.
* A tutorial showing how to deploy/run some Python trained models using TorchScript JIT.
* Some Reinforcement Learning examples using the OpenAI Gym environment. This includes a policy gradient example as well as an A2C implementation that can run on Atari games.
* A Transfer Learning Tutorial shows how to finetune a pre-trained ResNet model on a very small dataset.
* A simplified version of GPT similar to minGPT.
* A Stable Diffusion implementation following the lines of Hugging Face's diffusers library.

External material:
* A tutorial showing how to use Torch to compute option prices and greeks.
* tchrs-opencv-webcam-inference uses `tch-rs` and `opencv` to run inference on a webcam feed for some Python trained model based on mobilenet v3.
See some details in this thread.
Check this issue.
See this issue; this could be caused by rust-analyzer not knowing about the
proper environment variables like `LIBTORCH` and `LD_LIBRARY_PATH`.
`tch-rs` is distributed under the terms of both the MIT license
and the Apache license (version 2.0), at your option.
See LICENSE-APACHE, LICENSE-MIT for more details.