torchabc is a lightweight package that provides an Abstract Base Class (ABC) to structure PyTorch projects and keep code well organized.
The core of the package is the TorchABC class. This class defines the abstract training and inference workflows and must be subclassed to implement concrete logic.
This package has no dependencies beyond PyTorch and consists of a single self-contained file. It is ideal for research, prototyping, and teaching.
The TorchABC class structures a project into the following main steps:
- Dataloaders - load raw data samples.
- Preprocess - transform raw samples.
- Collate - batch preprocessed samples.
- Network - compute model outputs.
- Loss - compute error against targets.
- Optimizer - update model parameters.
- Postprocess - transform outputs into predictions.
Each step corresponds to an abstract method in TorchABC. To use TorchABC, create a concrete subclass and implement these methods.
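To show the order in which data flows through these steps, the sketch below wires them together by hand for one training step and one prediction. It is plain PyTorch, not torchabc code: the functions are hypothetical stand-ins named after the TorchABC methods, and the toy samples, the linear network, the `flag="train"` value, and the `"lr"`/`"threshold"` hyperparameters are assumptions. TorchABC performs this wiring for you, and its internal details may differ.

```python
import torch
from torch import nn
from torch.utils.data import default_collate

# Hypothetical stand-ins named after the TorchABC steps (plain PyTorch, not torchabc).
hparams = {"lr": 0.1, "threshold": 0.5}

def preprocess(sample, hparams, flag=""):
    # Raw (features, label) pair -> pair of float tensors.
    features, label = sample
    return torch.tensor(features, dtype=torch.float32), torch.tensor(label, dtype=torch.float32)

def collate(samples):
    # Stack the preprocessed pairs into batch-first tensors.
    return default_collate(samples)

network = nn.Linear(2, 1)  # (batch_size, 2) -> (batch_size, 1) logits
optimizer = torch.optim.SGD(network.parameters(), lr=hparams["lr"])

def loss(outputs, targets, hparams):
    # Return a dict with the key "loss".
    return {"loss": nn.functional.binary_cross_entropy_with_logits(outputs.squeeze(-1), targets)}

def postprocess(outputs, hparams):
    # Logits -> boolean predictions.
    return (torch.sigmoid(outputs.squeeze(-1)) > hparams["threshold"]).tolist()

# Raw samples, standing in for what the "train" dataloader would provide.
raw = [([0.0, 1.0], 1.0), ([1.0, 0.0], 0.0)]

# Training step: preprocess -> collate -> network -> loss -> optimizer.
inputs, targets = collate([preprocess(s, hparams, flag="train") for s in raw])
batch = loss(network(inputs), targets, hparams)
optimizer.zero_grad()
batch["loss"].backward()
optimizer.step()

# Prediction: network -> postprocess, without tracking gradients.
with torch.no_grad():
    preds = postprocess(network(inputs), hparams)
```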
Install the package.

```bash
pip install torchabc
```

Generate a template using the command line interface.

```bash
torchabc --create template.py --min
```

Fill out the template by implementing the methods below. Each method is documented in the method reference further down.
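```python
from functools import cached_property

import torch
from torchabc import TorchABC


class MyModel(TorchABC):

    @cached_property
    def dataloaders(self):
        # Return a dict of DataLoaders, e.g. {"train": ..., "val": ..., "test": ...}.
        raise NotImplementedError

    @staticmethod
    def preprocess(sample, hparams, flag=''):
        # Transform one raw sample; the template returns it unchanged.
        return sample

    @staticmethod
    def collate(samples):
        # Batch preprocessed samples with PyTorch's default collation.
        return torch.utils.data.default_collate(samples)

    @cached_property
    def network(self):
        # Return a torch.nn.Module with batch-first inputs and outputs.
        raise NotImplementedError

    @staticmethod
    def loss(outputs, targets, hparams):
        # Return a dict containing the key "loss".
        raise NotImplementedError

    @cached_property
    def optimizer(self):
        # Return a torch.optim.Optimizer over self.network.parameters().
        raise NotImplementedError

    @staticmethod
    def postprocess(outputs, hparams):
        # Turn network outputs into predictions; the template returns them unchanged.
        return outputs
```

Once a subclass of TorchABC is implemented, it can be used for training, evaluation, checkpointing, and inference.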
Initialize the model.

```python
model = MyModel()
```

Train the model for 5 epochs using the `"train"` and `"val"` dataloaders.

```python
model.train(epochs=5, on="train", val="val")
```

Evaluate on the `"test"` dataloader and return the metrics.

```python
metrics = model.eval(on="test")
```

Save and restore the model state.

```python
model.save("checkpoint.pth")
model.load("checkpoint.pth")
```

Run predictions on raw input samples.

```python
preds = model(samples)  # samples: an iterable of raw input samples
```
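The checkpoint format that `save` and `load` use is not described here. As a rough plain-PyTorch analogue of "saving and restoring the model state", the sketch below stores the network and optimizer `state_dict`s in one file with `torch.save` and copies them back with `torch.load`. The dictionary keys, the layer sizes, and the temporary path are assumptions for illustration, not torchabc's actual file layout.

```python
import os
import tempfile

import torch
from torch import nn

net = nn.Linear(4, 2)
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pth")
    # Save: one dict with everything needed to resume (the key names are illustrative).
    torch.save({"network": net.state_dict(), "optimizer": opt.state_dict()}, path)

    # Load: build fresh objects, then copy the saved state into them.
    restored_net = nn.Linear(4, 2)
    restored_opt = torch.optim.Adam(restored_net.parameters(), lr=1e-3)
    state = torch.load(path, map_location="cpu")
    restored_net.load_state_dict(state["network"])
    restored_opt.load_state_dict(state["optimizer"])
```

Saving the optimizer state alongside the weights is what allows training to resume with the same optimizer statistics rather than only the same parameters.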
The TorchABC class defines a standard workflow for PyTorch projects. Some methods are abstract (must be implemented in subclasses), others are optional (can be overridden but have defaults), and a few are concrete (should not be overridden).
| Abstract method | Description |
|---|---|
| `dataloaders` | Must return `dict[str, torch.utils.data.DataLoader]`. Example keys: `"train"`, `"val"`, `"test"`. |
| `preprocess(sample, hparams, flag='')` | Transform a raw dataset sample. Parameters: `sample` (`Any`): raw sample; `hparams` (`dict`): hyperparameters; `flag` (`str`, optional): mode flag. Returns: `Tensor` or iterable of tensors. |
| `collate(samples)` | Collate a batch of preprocessed samples. Parameters: `samples` (`Iterable[Tensor]`). Returns: `Tensor` or iterable of tensors. |
| `network` | Must return a `torch.nn.Module`. Inputs and outputs must use the `(batch_size, ...)` format. |
| `optimizer` | Must return a `torch.optim.Optimizer` for `self.network.parameters()`. |
| `loss(outputs, targets, hparams)` | Compute the loss for a batch. Parameters: `outputs` (`Tensor` or iterable); `targets` (`Tensor` or iterable); `hparams` (`dict`). Returns: `dict[str, Any]` containing the key `"loss"`. |
| `postprocess(outputs, hparams)` | Convert network outputs into predictions. Parameters: `outputs` (`Tensor` or iterable); `hparams` (`dict`). Returns: predictions (`Any`). |
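As an illustration of the contracts in the table above, the bodies below build a `dataloaders` dict, a batch-first network, an optimizer over its parameters, and a loss that returns a dict with the key `"loss"`. They are written as free functions for brevity; in a subclass, `dataloaders`, `network` and `optimizer` would be `cached_property` bodies using `self`, and `loss` a `staticmethod`. The random data, the split sizes, the `"batch_size"`/`"lr"` hyperparameters, the choice of Adam, and the extra `"accuracy"` entry are assumptions for the example.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Assumed hyperparameters for this example.
hparams = {"batch_size": 8, "lr": 1e-3}

def dataloaders():
    # dict[str, DataLoader] keyed by split name; random data for illustration.
    generator = torch.Generator().manual_seed(0)
    x = torch.randn(100, 4, generator=generator)
    y = torch.randint(0, 3, (100,), generator=generator)
    splits = {"train": slice(0, 60), "val": slice(60, 80), "test": slice(80, 100)}
    return {
        name: DataLoader(TensorDataset(x[s], y[s]),
                         batch_size=hparams["batch_size"],
                         shuffle=(name == "train"))
        for name, s in splits.items()
    }

def network():
    # Batch-first: (batch_size, 4) features -> (batch_size, 3) class logits.
    return nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 3))

def optimizer(net):
    # Optimizer over the network's parameters.
    return torch.optim.Adam(net.parameters(), lr=hparams["lr"])

def loss(outputs, targets, hparams):
    # dict[str, Any] with the required key "loss"; "accuracy" is an extra entry.
    value = nn.functional.cross_entropy(outputs, targets)
    accuracy = (outputs.argmax(dim=-1) == targets).float().mean().item()
    return {"loss": value, "accuracy": accuracy}

loaders = dataloaders()
net = network()
opt = optimizer(net)
inputs, targets = next(iter(loaders["train"]))
print(loss(net(inputs), targets, hparams))
```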
| Optional method | Description |
|---|---|
| `scheduler` | Learning rate scheduler. May return `None`, a `torch.optim.lr_scheduler.LRScheduler`, or a `ReduceLROnPlateau`. Default is `None`. |
| `backward(batch, gas)` | Backpropagation step. Parameters: `batch` (`dict[str, Any]`): must contain the key `"loss"`; `gas` (`int`): gradient accumulation steps. |
| `metrics(batches, hparams)` | Compute evaluation metrics. Parameters: `batches` (`deque[dict[str, Any]]`): batch results; `hparams` (`dict`). Returns: `dict[str, Any]`. The default computes the average loss. |
| `checkpoint(epoch, metrics, out)` | Checkpoint step. The default saves the model if the loss improves. Parameters: `epoch` (`int`): epoch number; `metrics` (`dict[str, float]`): validation metrics; `out` (`str` or `None`): output path to save checkpoints. Returns: `bool` indicating whether to stop early. |
| `move(data)` | Move data to the current device. Supports `Tensor`, `list`, `tuple`, `dict`. |
| `detach(data)` | Detach data from the computation graph. Supports `Tensor`, `list`, `tuple`, `dict`. |
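The `move` and `detach` rows above describe helpers that walk nested containers. A minimal sketch of that behaviour, assuming only the listed types (`Tensor`, `list`, `tuple`, `dict`) are traversed and everything else is returned unchanged. Unlike the TorchABC methods, which use the instance's current device, these hypothetical free functions take the device as an argument.

```python
import torch

def move(data, device):
    """Recursively move every tensor in nested lists, tuples and dicts to `device`."""
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        return {key: move(value, device) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        # Rebuild the same container type (plain lists/tuples; namedtuples would need *args).
        return type(data)(move(item, device) for item in data)
    return data  # other objects (numbers, strings, ...) pass through unchanged

def detach(data):
    """Recursively detach every tensor in nested lists, tuples and dicts from the graph."""
    if isinstance(data, torch.Tensor):
        return data.detach()
    if isinstance(data, dict):
        return {key: detach(value) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(detach(item) for item in data)
    return data

x = torch.ones(2, requires_grad=True)
batch = {"loss": (x * 2).sum(), "parts": [x * 3, (x, "tag")]}
clean = detach(move(batch, "cpu"))
```

Detaching batch results before storing them (for example, in the `deque` passed to `metrics`) keeps old computation graphs from being held in memory.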
| Concrete method | Description |
|---|---|
| `TorchABC(device=None, logger=print, hparams=None, **kwargs)` | Initialize the model. Parameters: `device` (`str` or `torch.device`, optional): computation device, defaults to CUDA if available, otherwise MPS or CPU; `logger` (`Callable[[dict], None]`, optional): logging function, defaults to `print`; `hparams` (`dict`, optional): dictionary of hyperparameters; `kwargs`: additional attributes stored in the instance. |
| `train(epochs, gas=1, mas=None, on='train', val='val', out=None)` | Train the model. Parameters: `epochs` (`int`): number of training epochs; `gas` (`int`, optional): gradient accumulation steps, defaults to 1; `mas` (`int`, optional): metrics accumulation steps, defaults to `gas`; `on` (`str`, optional): training dataloader name, defaults to `"train"`; `val` (`str`, optional): validation dataloader name, defaults to `"val"` (if `None`, validation is skipped); `out` (`str`, optional): output path to save checkpoints. |
| `eval(on)` | Evaluate the model. Parameters: `on` (`str`): dataloader name. Returns: `dict[str, float]` of evaluation metrics. |
| `__call__(samples)` | Run inference on raw samples. Parameters: `samples` (`Iterable[Any]`): raw samples. Returns: postprocessed predictions. |
| `save(path)` | Save a checkpoint. Parameters: `path` (`str`): file path. |
| `load(path)` | Load a checkpoint. Parameters: `path` (`str`): file path. |
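The `gas` parameter of `train` (and of `backward`) sets the number of gradient accumulation steps. The sketch below shows the common form of this technique in plain PyTorch: each batch loss is divided by `gas` before `backward()`, and the optimizer steps only once every `gas` batches, so each update approximates a single step on a batch `gas` times larger. This is a generic illustration rather than torchabc's source; the toy data, the MSE loss, and the hypothetical `train_epoch` helper are assumptions, and a trailing group of fewer than `gas` batches does not trigger an optimizer step here.

```python
import torch
from torch import nn

def train_epoch(network, optimizer, batches, gas=1):
    """Run one epoch with gradient accumulation; return the number of optimizer steps."""
    network.train()
    optimizer.zero_grad()
    steps = 0
    for i, (inputs, targets) in enumerate(batches, start=1):
        loss = nn.functional.mse_loss(network(inputs), targets)
        # Scale so that the `gas` accumulated gradients average instead of sum.
        (loss / gas).backward()
        if i % gas == 0:
            optimizer.step()
            optimizer.zero_grad()
            steps += 1
    return steps

torch.manual_seed(0)
net = nn.Linear(3, 1)
opt = torch.optim.SGD(net.parameters(), lr=0.01)
batches = [(torch.randn(4, 3), torch.randn(4, 1)) for _ in range(6)]
print(train_epoch(net, opt, batches, gas=3))  # prints 2: steps after batches 3 and 6
```

With equal-sized batches and a mean-reduced loss, dividing by `gas` makes the accumulated gradient equal to the gradient of the mean loss over all `gas` batches combined.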
Get started with simple self-contained examples.

Install the dependencies:

```bash
poetry install --with examples
```

Run an example by replacing `<name>` with one of the filenames in the `examples` folder:

```bash
poetry run python examples/<name>.py
```
Contributions are welcome! Submit pull requests with new examples or improvements to the core TorchABC class itself.
