Deep Learning library using [custos] and [custos-math].
External (C) dependencies: OpenCL, CUDA, nvrtc, cuBLAS, BLAS
There are two features available, both enabled by default:
- `cuda`: CUDA, nvrtc and cuBLAS are needed to run
- `opencl`: OpenCL is needed
If you deactivate them (add `default-features = false` and provide no additional features), only the CPU device can be used.
For all feature configurations, a BLAS library needs to be installed on the system.
```toml
[dependencies]
gradients = "0.1.0"
```
(if this example does not compile, consider looking here)
Use a struct that implements the `NeuralNetwork` trait to define which layers you want to use:
```rust
use gradients::purpur::{CSVLoader, CSVReturn, Converter};
use gradients::{
    correct_classes,
    nn::{cce, cce_grad},
    range, Adam, AsDev, CLDevice, Linear, NeuralNetwork, OnehotOp, ReLU, Softmax,
};

pub struct Network
```
You can download the MNIST dataset here.
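The imports and the `Network { lin1, lin2, lin3, .. }` instantiation suggest a stack of `Linear` layers combined with `ReLU` and `Softmax` activations. As a rough, dependency-free illustration of what such a forward pass computes (not the library's implementation), here is a plain-Rust sketch. The `Dense` type, `mlp_forward`, and the assumption that a ReLU follows every layer except the last are all hypothetical.

```rust
/// Hypothetical fully connected layer: out = x · W + b.
/// `weights` is row-major with shape (inputs × outputs).
struct Dense {
    weights: Vec<f32>,
    bias: Vec<f32>,
    inputs: usize,
    outputs: usize,
}

impl Dense {
    /// Forward pass for a single sample `x` of length `inputs`.
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        assert_eq!(x.len(), self.inputs, "input length mismatch");
        let mut out = self.bias.clone();
        for (i, &xi) in x.iter().enumerate() {
            // row i of W holds the weights from input i to every output
            let row = &self.weights[i * self.outputs..(i + 1) * self.outputs];
            for (o, &w) in row.iter().enumerate() {
                out[o] += xi * w;
            }
        }
        out
    }
}

/// ReLU activation: max(0, v), element-wise.
fn relu(v: &[f32]) -> Vec<f32> {
    v.iter().map(|&a| a.max(0.0)).collect()
}

/// Softmax; subtracting the maximum first avoids overflow in exp().
fn softmax(v: &[f32]) -> Vec<f32> {
    let max = v.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = v.iter().map(|&a| (a - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Assumed ordering: Dense, ReLU, ..., Dense, Softmax
/// (ReLU after every layer except the last).
fn mlp_forward(layers: &[Dense], x: &[f32]) -> Vec<f32> {
    let mut h = x.to_vec();
    for (idx, layer) in layers.iter().enumerate() {
        h = layer.forward(&h);
        if idx + 1 < layers.len() {
            h = relu(&h);
        }
    }
    softmax(&h)
}
```

With the shapes used in the README, the three layers would be `Dense` instances of size 784×128, 128×10 and 10×10, and the final softmax turns the 10 outputs into class probabilities that sum to 1.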
```rust
// use cpu (no features enabled):
// let device = gradients::CPU::new().select();

// use cuda device (cuda feature enabled):
// let device = gradients::CudaDevice::new(0).unwrap().select();

// use opencl device (opencl feature enabled):
let device = CLDevice::new(0).unwrap().select();
let loader = CSVLoader::new(true);
let loaded_data: CSVReturn

let i = Matrix::from((
    &device,
    (loaded_data.sample_count, loaded_data.features),
    &loaded_data.x,
));
// scale pixel values from 0..=255 into 0..=1
let i = i / 255.;

let y = Matrix::from((&device, (loaded_data.sample_count, 1), &loaded_data.y));
// one-hot encode the labels
let y = device.onehot(y);

let mut net = Network {
    lin1: Linear::new(784, 128),
    lin2: Linear::new(128, 10),
    lin3: Linear::new(10, 10),
    ..Default::default()
};
```
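The two preprocessing steps above, `i / 255.` and `device.onehot(y)`, can be illustrated without any device. The following std-only sketch uses hypothetical helper names (`scale_pixels`, `one_hot`), assumes labels arrive as whole-number class indices stored as `f32`, and takes the number of classes (10 for MNIST) as an explicit parameter; `gradients` may determine it differently.

```rust
/// Scale raw 0 to 255 pixel intensities into [0, 1], mirroring `i / 255.`.
fn scale_pixels(pixels: &[f32]) -> Vec<f32> {
    pixels.iter().map(|&p| p / 255.0).collect()
}

/// One-hot encode class labels into a row-major (labels.len() × classes) matrix.
/// Labels are assumed to be whole-number class indices stored as f32.
fn one_hot(labels: &[f32], classes: usize) -> Vec<f32> {
    let mut out = vec![0.0; labels.len() * classes];
    for (row, &label) in labels.iter().enumerate() {
        let class = label as usize;
        assert!(class < classes, "label {class} is out of range");
        // set the single 1.0 in this row at the label's column
        out[row * classes + class] = 1.0;
    }
    out
}
```

For example, `one_hot(&[2.0, 0.0], 3)` produces the rows `[0, 0, 1]` and `[1, 0, 0]`, flattened into one vector of length 6.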
Training loop:
```rust
let mut opt = Adam::new(0.01);

for epoch in range(200) {
    // forward pass over the whole training set
    let preds = net.forward(i);
    let correct_training = correct_classes(&loaded_data.y.as_usize(), preds) as f32;

    // categorical cross-entropy between predictions and one-hot targets
    let loss = cce(&device, &preds, &y);
    println!(
        "epoch: {epoch}, loss: {loss}, training_acc: {acc}",
        acc = correct_training / loaded_data.sample_count as f32
    );

    // backpropagate the loss gradient through the network
    let grad = cce_grad(&device, &preds, &y);
    net.backward(grad);

    // update all parameters with Adam
    opt.step(&device, net.params());
}
```
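For intuition about what `cce`, `cce_grad` and `correct_classes` compute in the loop above, here is a std-only sketch using standard textbook definitions: the mean categorical cross-entropy over the batch, its gradient with respect to the softmax outputs, and an arg-max accuracy count. The function names are hypothetical, and the library's exact normalization (for example, whether the gradient is divided by the batch size or fused with the softmax derivative) may differ.

```rust
/// Mean categorical cross-entropy: -(1/N) Σ_n Σ_c t[n][c] · ln p[n][c].
/// `preds` (probabilities) and `targets` (one-hot) are row-major (N × classes).
fn mean_cce(preds: &[f32], targets: &[f32], classes: usize) -> f32 {
    let rows = (preds.len() / classes) as f32;
    let total: f32 = preds
        .iter()
        .zip(targets)
        // clamp p away from 0 so ln() stays finite
        .map(|(&p, &t)| -t * p.max(1e-7).ln())
        .sum();
    total / rows
}

/// Gradient of `mean_cce` w.r.t. `preds`: -t / (N · p), element-wise.
fn mean_cce_grad(preds: &[f32], targets: &[f32], classes: usize) -> Vec<f32> {
    let rows = (preds.len() / classes) as f32;
    preds
        .iter()
        .zip(targets)
        .map(|(&p, &t)| -t / (rows * p.max(1e-7)))
        .collect()
}

/// Index of the largest value in a row (the first one wins on ties).
fn argmax(row: &[f32]) -> usize {
    let mut best = 0;
    for (i, &v) in row.iter().enumerate() {
        if v > row[best] {
            best = i;
        }
    }
    best
}

/// Number of rows whose arg-max prediction equals the integer label.
fn count_correct(labels: &[usize], preds: &[f32], classes: usize) -> usize {
    preds
        .chunks(classes)
        .zip(labels)
        .filter(|&(row, &label)| argmax(row) == label)
        .count()
}
```

Under these definitions, the training accuracy printed in the loop corresponds to `count_correct(..) as f32` divided by the number of samples.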