BitNetMCU is a project focused on the training and inference of low-bit quantized neural networks, designed to run efficiently on low-end microcontrollers such as the CH32V003. Quantization-aware training (QAT) and careful fine-tuning of the model structure and inference code made it possible to surpass 99% test accuracy on a 16x16 MNIST dataset without using multiplication instructions, in only 2 kB of RAM and 16 kB of flash.
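For context, quantization-aware training here means the weights are fake-quantized to the low-bit grid during the forward pass while gradients flow through the rounding step unchanged (a straight-through estimator), so the network learns to compensate for the quantization error. The snippet below is a minimal, generic PyTorch sketch of this idea, not the actual QAT classes in BitNetMCU.py; layer and parameter names are placeholders.

```python
# Minimal QAT sketch (illustrative only, not the BitNetMCU.py implementation).
import torch
import torch.nn as nn
import torch.nn.functional as F

class FakeQuantLinear(nn.Linear):
    """Linear layer whose weights are fake-quantized to a low-bit grid."""
    def __init__(self, in_features, out_features, bits=4):
        super().__init__(in_features, out_features, bias=False)
        self.qmax = 2 ** (bits - 1) - 1  # e.g. 7 for 4-bit

    def forward(self, x):
        # Per-tensor scale so the largest weight maps to the edge of the grid.
        scale = self.weight.abs().max() / self.qmax
        w_q = torch.clamp(torch.round(self.weight / scale), -self.qmax, self.qmax) * scale
        # Straight-through estimator: forward uses w_q, backward sees identity.
        w_ste = self.weight + (w_q - self.weight).detach()
        return F.linear(x, w_ste)

# Tiny fully connected classifier for 16x16 = 256 input pixels (placeholder sizes).
model = nn.Sequential(FakeQuantLinear(256, 64), nn.ReLU(), FakeQuantLinear(64, 10))
logits = model(torch.randn(8, 256))  # batch of 8 flattened images
```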
Update: A new model architecture based on depthwise separable convolutions pushes the accuracy even further, to 99.55%, matching state-of-the-art MNIST accuracy for CNNs while still fitting into the same memory constraints. This model requires a hardware multiplier, which is available in many low-end RISC-V and ARM Cortex-M0 microcontrollers.
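A depthwise separable convolution factors a standard convolution into a per-channel (depthwise) filter followed by a 1x1 (pointwise) convolution that mixes channels, which greatly reduces multiplications and parameters. The block below is a generic PyTorch sketch of that structure, not the actual BitNetMCU CNN; channel counts and layer order are placeholders.

```python
# Generic depthwise separable convolution block (illustrative only).
import torch
import torch.nn as nn

def depthwise_separable(in_ch, out_ch):
    return nn.Sequential(
        # Depthwise: one 3x3 filter per input channel (groups=in_ch).
        nn.Conv2d(in_ch, in_ch, kernel_size=3, padding=1, groups=in_ch, bias=False),
        nn.ReLU(),
        # Pointwise: 1x1 convolution mixes information across channels.
        nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
        nn.ReLU(),
    )

block = depthwise_separable(16, 32)
y = block(torch.randn(1, 16, 16, 16))   # (batch, channels, height, width)
print(y.shape)                          # torch.Size([1, 32, 16, 16])
```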
The training pipeline is based on PyTorch and should run anywhere. The inference engine is implemented in ANSI C and can be easily ported to any microcontroller.
You can find a detailed report on the project, as well as the CNN details, in the docs/ directory.
BitNetMCU/
│
├── docs/ # Report
├── mcu/ # MCU specific code for CH32V003
├── modeldata/ # Pre-trained models
│
├── BitNetMCU.py # Pytorch model and QAT classes
├── BitNetMCU_inference.c # C code for inference
├── BitNetMCU_inference.h # Header file for C inference code
├── BitNetMCU_MNIST_test.c # Test script for MNIST dataset
├── BitNetMCU_MNIST_test_data.h # MNIST test data in header format (generated)
├── BitNetMCU_model.h # Model data in C header format (generated)
├── exportquant.py # Script to convert trained model to quantized format
├── test_inference.py # Script to test C implementation of inference
├── training.py # Training script for the neural network
└── trainingparameters.yaml # Configuration file for training parameters
The data pipeline is split into several Python scripts for flexibility:

- Configuration: Modify trainingparameters.yaml to set all hyperparameters for training the model.
- Training the Model: The training.py script trains the model and stores the weights as a .pth file in the modeldata/ folder. The model weights are still in float format at this stage, as they are quantized on-the-fly during training.
- Exporting the Quantized Model: The exportquant.py script converts the trained model into a quantized format. The quantized model weights are exported to the C header file BitNetMCU_model.h (a simplified sketch of this transformation follows after this list).
- Optional: Testing the C-Model: Compile and execute BitNetMCU_MNIST_test.c to test inference on ten digits. The model data is included from BitNetMCU_model.h, and the test data is included from BitNetMCU_MNIST_test_data.h.
- Optional: Verification of the C vs. Python model on the full dataset: The inference code, along with the model data, is compiled into a DLL. The test_inference.py script calls the DLL and compares the results with the original Python model. This allows an accurate comparison across the entire MNIST test set of 10,000 images (see the ctypes sketch after this list).
- Optional: Testing inference on the MCU: Follow the instructions in mcu/readme.md. Porting to architectures other than the CH32V003 is straightforward, and the files in the mcu/ directory can serve as a reference.
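To make the export step concrete, here is a simplified sketch of the kind of transformation exportquant.py performs: mapping a float weight tensor onto a 4-bit integer grid and emitting the values as a C array. The scaling rule, array name, and header layout below are placeholders for illustration and do not match the project's actual export format.

```python
# Illustrative only: quantize float weights to 4-bit integers and emit a C header.
# The real exportquant.py uses its own quantization schemes and packing.
import numpy as np

def quantize_4bit(weights: np.ndarray):
    """Map float weights to integers in [-7, 7] with a per-tensor scale."""
    scale = np.abs(weights).max() / 7.0
    q = np.clip(np.round(weights / scale), -7, 7).astype(np.int8)
    return q, scale

def write_c_header(q: np.ndarray, scale: float, path="example_model.h"):
    values = ", ".join(str(int(v)) for v in q.flatten())
    header = (f"// {q.shape[0]}x{q.shape[1]} weights, scale = {scale:.6f}\n"
              f"const int8_t example_weights[{q.size}] = {{{values}}};\n")
    with open(path, "w") as f:
        f.write(header)

weights = np.random.randn(10, 64).astype(np.float32)  # stand-in for a trained layer
q, scale = quantize_4bit(weights)
write_c_header(q, scale)
```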
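The C-vs-Python verification works by compiling the inference code into a shared library and calling it from Python on the same test images as the PyTorch model. The snippet below outlines that idea with ctypes; the library name, exported function, and its signature are assumptions made for this sketch and may differ from what test_inference.py actually uses.

```python
# Illustrative only: call compiled C inference code from Python via ctypes.
# Library name, function name, and signature are assumptions for this sketch.
import ctypes
import numpy as np

lib = ctypes.CDLL("./inference.dll")                    # hypothetical library name
lib.Inference.argtypes = [ctypes.POINTER(ctypes.c_int8)]
lib.Inference.restype = ctypes.c_uint32

def predict_c(image_16x16: np.ndarray) -> int:
    """Run one 16x16 image (int8 pixels) through the C inference code."""
    buf = np.ascontiguousarray(image_16x16.astype(np.int8).flatten())
    ptr = buf.ctypes.data_as(ctypes.POINTER(ctypes.c_int8))
    return int(lib.Inference(ptr))

# Conceptual comparison loop against the Python model:
# for image, label in mnist_test:
#     assert predict_c(image) == int(python_model(image).argmax())
```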
- 24th April 2024 - First release with Binary, Ternary, 2-bit, 4-bit, and 8-bit quantization.
- 2nd May 2024 - tagged version 0.1a
- 8th May 2024 - Added FP1.3.0 Quantization to allow fully multiplication-free inference with 98.9% accuracy.
- 11th May 2024 - Fixes for Linux. Thanks to @donn
- 19th May 2024 - Added support for a non-symmetric 4-bit quantization scheme that allows for easier inference on MCUs with a multiplier. The inference code now uses the multiplication-free code path only on RV32 architectures without a hardware multiplier.
- 20th May 2024 - Added quantscale as a hyperparameter to influence weight scaling. Updated documentation on new quantization schemes.
- 26th May 2024 - tagged version 0.2a
- 19th July 2024 - Added the OCTAV algorithm to calculate optimal clipping and quantization parameters.
- 26th July 2024 - Added support for NormalFloat4 (NF4) quantization. Updated documentation.
- 7th September 2025 - New CNN architecture based on sequential depthwise separable convolutions reaches 99.55% accuracy while still fitting into 16 kB of flash and 4 kB of RAM. See the documentation for details.
