Naive Bayes

linfa-bayes provides pure Rust implementations of Naive Bayes algorithms for the Linfa toolkit.

The Big Picture

linfa-bayes is a crate in the linfa ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's scikit-learn.

Current state

linfa-bayes currently provides implementations of Naive Bayes methods including Gaussian Naive Bayes (GaussianNb) and Multinomial Naive Bayes (MultinomialNb).

Examples

You can find examples in the examples/ directory. To run the Gaussian Naive Bayes example, use:

```bash
$ cargo run --example winequality --release
```

```rust, no_run
use linfa::metrics::ToConfusionMatrix;
use linfa::traits::{Fit, Predict};
use linfa_bayes::{GaussianNb, Result};

fn main() -> Result<()> {
    // Read in the dataset and convert targets to binary data
    let (train, valid) = linfa_datasets::winequality()
        .map_targets(|x| if *x > 6 { "good" } else { "bad" })
        .split_with_ratio(0.9);

    // Train the model
    let model = GaussianNb::params().fit(&train)?;

    // Predict the validation dataset
    let pred = model.predict(&valid);

    // Construct confusion matrix
    let cm = pred.confusion_matrix(&valid)?;

    // classes | bad | good
    // bad     | 130 | 12
    // good    | 7   | 10
    //
    // accuracy 0.8805031, MCC 0.45080978
    println!("{:?}", cm);
    println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc());

    Ok(())
}
```

To run the Multinomial Naive Bayes example, use:

```bash
$ cargo run --example winequality_multinomial --release
```

```rust, no_run
use linfa::metrics::ToConfusionMatrix;
use linfa::traits::{Fit, Predict};
use linfa_bayes::{MultinomialNb, Result};

fn main() -> Result<()> {
    // Read in the dataset and convert targets to binary data
    let (train, valid) = linfa_datasets::winequality()
        .map_targets(|x| if *x > 6 { "good" } else { "bad" })
        .split_with_ratio(0.9);

    // Train the model
    let model = MultinomialNb::params().fit(&train)?;

    // Predict the validation dataset
    let pred = model.predict(&valid);

    // Construct confusion matrix
    let cm = pred.confusion_matrix(&valid)?;

    // classes | bad | good
    // bad     | 88  | 54
    // good    | 10  | 7
    //
    // accuracy 0.5974843, MCC 0.02000631
    println!("{:?}", cm);
    println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc());

    Ok(())
}
```
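
For readers who want to check the numbers quoted in the comments of the two examples, here is a minimal, dependency-free sketch that recomputes accuracy and the Matthews correlation coefficient (MCC) from the four counts of each confusion matrix. The helper functions `accuracy` and `mcc` are hypothetical names written for this illustration and are not part of linfa; the sketch uses the standard textbook formulas and assumes linfa's metrics follow them for the binary case. For a 2x2 matrix both metrics are unchanged if the two off-diagonal counts are swapped, so the orientation of rows and columns does not matter here.

```rust
/// Accuracy of a 2x2 confusion matrix: correct predictions over all predictions.
/// `a` and `d` are the diagonal counts, `b` and `c` the off-diagonal counts.
/// (Illustrative helper, not a linfa function.)
fn accuracy(a: f64, b: f64, c: f64, d: f64) -> f64 {
    (a + d) / (a + b + c + d)
}

/// Matthews correlation coefficient of a 2x2 confusion matrix.
/// Returns 0.0 when any marginal sum is zero, which is a common convention.
/// (Illustrative helper, not a linfa function.)
fn mcc(a: f64, b: f64, c: f64, d: f64) -> f64 {
    let denom = ((a + b) * (a + c) * (d + b) * (d + c)).sqrt();
    if denom == 0.0 {
        0.0
    } else {
        (a * d - b * c) / denom
    }
}

fn main() {
    // Counts copied from the comments in the two examples above.
    let gaussian = (130.0, 12.0, 7.0, 10.0);
    let multinomial = (88.0, 54.0, 10.0, 7.0);

    for (name, (a, b, c, d)) in [("gaussian", gaussian), ("multinomial", multinomial)] {
        println!(
            "{}: accuracy {:.4}, MCC {:.4}",
            name,
            accuracy(a, b, c, d),
            mcc(a, b, c, d)
        );
    }
}
```

The results agree with the quoted values up to rounding; the sketch works in `f64`, so trailing digits may differ slightly from what the examples print. An MCC near 0.45 for the Gaussian model against roughly 0.02 for the multinomial one shows that, on this split, the multinomial predictions are barely better than chance even though their accuracy is close to 60%.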