A neural network and dynamic tensor automatic differentiation implementation for Rust.
```rust
for _ in 0..iterations {
    // array operations are never in-place for corgi, so values never change
    let input = Array::from((vec![batch_size, input_size], vec![...]));
    let target = Array::from((vec![batch_size, output_size], vec![...]));

    let _result = model.forward(input);
    let loss = model.backward(target);
    // update the parameters, and clear gradients (backward pass only sets gradients)
    model.update();

    println!("loss: {}", loss);
}
```
* An `Array` is responsible for differentiating the operations performed on it, for use in the backward pass.
* For ergonomics, there is no separate graph structure: an `Array` contains only its children.
* Arrays do not store their consumers (at the moment); they store consumer counts instead.
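As an illustration of these notes, here is a minimal scalar reverse-mode sketch in plain Rust. It is not corgi's implementation: `Node`, `Var`, and `propagate` are hypothetical names, and the rule that a dropped consumer stops counting is an assumption made so the counts stay accurate. Each node holds only its children (each paired with the local derivative) plus a count of its consumers. During the backward pass a node forwards its gradient to its children only once all of its consumers have contributed, so a shared node is expanded exactly once.

```rust
use std::cell::Cell;
use std::ops::{Add, Mul};
use std::rc::Rc;

// Hypothetical node: it stores its children (each with the local derivative
// d(self)/d(child)) and a consumer *count*, never the consumers themselves.
struct Node {
    value: f64,
    grad: Cell<f64>,
    children: Vec<(Rc<Node>, f64)>,
    consumers: Cell<usize>, // live nodes that use this node as a child
    received: Cell<usize>,  // consumers that have contributed a gradient so far
}

// Sketch assumption: a consumer that is dropped no longer counts.
impl Drop for Node {
    fn drop(&mut self) {
        for (child, _) in &self.children {
            child.consumers.set(child.consumers.get() - 1);
        }
    }
}

struct Var(Rc<Node>);

impl Var {
    fn new(value: f64) -> Var {
        Var::with_children(value, Vec::new())
    }

    fn with_children(value: f64, children: Vec<(Rc<Node>, f64)>) -> Var {
        // Registering a consumer only bumps each child's count.
        for (child, _) in &children {
            child.consumers.set(child.consumers.get() + 1);
        }
        Var(Rc::new(Node {
            value,
            grad: Cell::new(0.0),
            children,
            consumers: Cell::new(0),
            received: Cell::new(0),
        }))
    }

    fn value(&self) -> f64 { self.0.value }
    fn gradient(&self) -> f64 { self.0.grad.get() }

    // Seed d(root)/d(root) = 1 and push gradients towards the leaves.
    fn backward(&self) {
        self.0.grad.set(1.0);
        propagate(&self.0);
    }
}

// A child is expanded only after all of its consumers have contributed.
// Assumes every live consumer is part of the graph being differentiated.
fn propagate(node: &Rc<Node>) {
    for (child, local) in &node.children {
        child.grad.set(child.grad.get() + node.grad.get() * *local);
        child.received.set(child.received.get() + 1);
        if child.received.get() == child.consumers.get() {
            child.received.set(0);
            propagate(child);
        }
    }
}

impl Add for &Var {
    type Output = Var;
    fn add(self, rhs: Self) -> Var {
        let children = vec![(self.0.clone(), 1.0), (rhs.0.clone(), 1.0)];
        Var::with_children(self.value() + rhs.value(), children)
    }
}

impl Mul for &Var {
    type Output = Var;
    fn mul(self, rhs: Self) -> Var {
        let children = vec![(self.0.clone(), rhs.value()), (rhs.0.clone(), self.value())];
        Var::with_children(self.value() * rhs.value(), children)
    }
}

fn main() {
    // Same data-dependent loop as the dynamic computational graph example.
    let a = Var::new(5.0);
    let b = Var::new(2.0);
    let mut c = Var::new(0.0);
    for _ in 0..10 {
        c = &c + &(&a * &b);
        if c.value() > 50.0 {
            c = &c * &a;
        }
    }
    c.backward();
    // prints: c = 195300, dc/da = 232420, dc/db = 97650
    println!("c = {}, dc/da = {}, dc/db = {}", c.value(), a.gradient(), b.gradient());
}
```

Because `a` is consumed fifteen times in that loop, its gradient is only final once all fifteen contributions have arrived, which is exactly what the consumer count tracks.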
* The `openblas` or `netlib` features (BLAS backends) can be enabled.
* If gradients are required, `tracked()` or `start_tracking()` must be used (see the documentation for details).
* For more information, see the documentation for `tracked()` and `untracked()` in `array.rs`.

* Fully-connected neural network training loop:
```rust
for _ in 0..iterations {
    let mut input = vec![0.0; input_size * batch_size];
    let mut target = vec![0.0; output_size * batch_size];

    // set inputs, and targets

    // arrays in corgi should not be mutated after creation, so we initialise the values first
    let input = Array::from((vec![batch_size, input_size], input));
    let target = Array::from((vec![batch_size, output_size], target));

    let _result = model.forward(input);
    let loss = model.backward(target);
    // update the parameters, and clear gradients (backward pass only sets gradients)
    model.update();

    println!("loss: {}", loss);
}
```
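The training loop above calls `forward`, `backward`, and `update` on a `model` whose construction is not shown. The division of labour described in its comments (the backward pass only sets gradients; the update applies them and then clears them) can be sketched without corgi. The sketch below assumes a hypothetical one-input linear model with a mean-squared-error loss and plain gradient descent; `Linear`, `train`, the learning rate, and the toy data are illustrative assumptions, not corgi's API.

```rust
// Hypothetical model, not corgi's API: y = w * x + b, trained with
// mean squared error and plain gradient descent.
struct Linear {
    w: f64,
    b: f64,
    grad_w: f64,
    grad_b: f64,
    learning_rate: f64,
    inputs: Vec<f64>, // cached by forward() for use in backward()
    outputs: Vec<f64>,
}

impl Linear {
    fn new(learning_rate: f64) -> Self {
        Linear {
            w: 0.0,
            b: 0.0,
            grad_w: 0.0,
            grad_b: 0.0,
            learning_rate,
            inputs: Vec::new(),
            outputs: Vec::new(),
        }
    }

    // Forward pass: compute predictions and cache what backward() needs.
    fn forward(&mut self, input: Vec<f64>) -> Vec<f64> {
        let (w, b) = (self.w, self.b);
        self.outputs = input.iter().map(|x| w * x + b).collect();
        self.inputs = input;
        self.outputs.clone()
    }

    // Backward pass: returns the loss and accumulates gradients.
    // It never touches the parameters themselves.
    fn backward(&mut self, target: Vec<f64>) -> f64 {
        let n = target.len() as f64;
        let mut loss = 0.0;
        for ((x, y), t) in self.inputs.iter().zip(&self.outputs).zip(&target) {
            let err = y - t;
            loss += err * err / n;
            self.grad_w += 2.0 * err * x / n;
            self.grad_b += 2.0 * err / n;
        }
        loss
    }

    // Update: one gradient-descent step, then clear the gradients.
    fn update(&mut self) {
        self.w -= self.learning_rate * self.grad_w;
        self.b -= self.learning_rate * self.grad_b;
        self.grad_w = 0.0;
        self.grad_b = 0.0;
    }
}

// Fits y = 2x + 1 and returns the learned (w, b) and the final loss.
fn train(iterations: usize) -> (f64, f64, f64) {
    let mut model = Linear::new(0.1);
    let mut loss = f64::INFINITY;
    for _ in 0..iterations {
        // Build the values first, then hand them over; nothing is mutated afterwards.
        let input = vec![0.0, 1.0, 2.0, 3.0];
        let target: Vec<f64> = input.iter().map(|x| 2.0 * x + 1.0).collect();

        let _result = model.forward(input);
        loss = model.backward(target);
        model.update();
    }
    (model.w, model.b, loss)
}

fn main() {
    let (w, b, loss) = train(500);
    println!("w = {:.3}, b = {:.3}, loss = {:e}", w, b, loss);
}
```

Keeping `backward` free of parameter changes means several backward passes could accumulate into the same gradients before a single `update`, which is one reason to split the two steps.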
* Dynamic computational graph:
```rust
let a = arr![5.0].tracked();
let b = arr![2.0].tracked();
let mut c = arr![0.0].tracked();

for _ in 0..10 {
    c = &c + &(&a * &b);
    if c[0] > 50.0 {
        c = &c * &a;
    }
}

assert_eq!(c, arr![195300.0]);

c.backward(None);
assert_eq!(c.gradient(), arr![1.0]);
assert_eq!(b.gradient(), arr![97650.0]);
assert_eq!(a.gradient(), arr![232420.0]);
```
* Custom operation (still needs some work).
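The gradients asserted in the dynamic computational graph example can be cross-checked independently with forward-mode differentiation using dual numbers, which is a different technique from the backward pass corgi performs. The sketch below is plain Rust with no corgi types; `Dual` and `run` are hypothetical names. Forward mode needs one pass per input (seeding that input's derivative with 1), whereas a single backward pass yields the gradients of all inputs at once.

```rust
use std::ops::{Add, Mul};

// Forward-mode dual number: `v` is the value, `d` the derivative with
// respect to whichever input was seeded with d = 1.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Dual {
    v: f64,
    d: f64,
}

impl Add for Dual {
    type Output = Dual;
    fn add(self, rhs: Dual) -> Dual {
        Dual { v: self.v + rhs.v, d: self.d + rhs.d }
    }
}

impl Mul for Dual {
    type Output = Dual;
    // Product rule: (uv)' = u'v + uv'.
    fn mul(self, rhs: Dual) -> Dual {
        Dual { v: self.v * rhs.v, d: self.d * rhs.v + self.v * rhs.d }
    }
}

// The same data-dependent loop as the dynamic computational graph example.
fn run(a: Dual, b: Dual) -> Dual {
    let mut c = Dual { v: 0.0, d: 0.0 };
    for _ in 0..10 {
        c = c + a * b;
        if c.v > 50.0 {
            c = c * a;
        }
    }
    c
}

fn main() {
    // One forward pass per input: seed the input of interest with d = 1.
    let da = run(Dual { v: 5.0, d: 1.0 }, Dual { v: 2.0, d: 0.0 });
    let db = run(Dual { v: 5.0, d: 0.0 }, Dual { v: 2.0, d: 1.0 });
    // prints: c = 195300, dc/da = 232420, dc/db = 97650
    println!("c = {}, dc/da = {}, dc/db = {}", da.v, da.d, db.d);
}
```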
Much of the library was designed to be as dynamic as possible, so, where well chosen, some design choices may be similar to those of other dynamic computational graph libraries.
Third-party libraries were used; they are listed in `Cargo.toml`.