Documentation
Complete reference for the AxiomML experiment framework, covering all configuration parameters, API methods, and advanced usage.
Overview
AxiomML is a modular, config-driven ML experiment framework built on PyTorch. It separates experiment logic from configuration, enabling reproducible runs with zero boilerplate. Features include:
- YAML/JSON config — Declare every experiment parameter in a single file.
- W&B integration — Automatic metric logging, artifact tracking, and rich media.
- Distributed training — DDP with mixed precision, gradient checkpointing, and sync BN.
- Checkpoint management — Top-K saving by metric, auto-resume, versioned artifacts.
- Modular components — Swap models, datasets, optimizers, schedulers without touching training code.
Installation
# From PyPI
pip install axioml
# From source (latest)
pip install git+https://github.com/anomalyco/AxiomML.git
Requires Python 3.10+ and PyTorch 2.0+. CUDA is optional but recommended.
Quick Start
Create a config file and run an experiment in three steps:
# 1. Create config.yaml
model: mlp
dataset: mnist
trainer:
  epochs: 5
  batch_size: 64
  learning_rate: 0.001
# 2. Write train.py
from axioml import Experiment
exp = Experiment("config.yaml")
exp.train()
metrics = exp.evaluate()
# 3. Run
python train.py
Configuration System
Configs are YAML files (or JSON) that define every aspect of an experiment. The system supports:
- Hierarchical merge — Base configs with overrides.
- CLI overrides — python train.py --config trainer.epochs=20
- Presets — Predefined configs for common architectures.
- Type coercion — Automatic type conversion from YAML to Python types.
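Hierarchical merging and dotted CLI overrides can be pictured with a short sketch. This is not AxiomML's implementation: `deep_merge` and `parse_override` are hypothetical helper names, and JSON-style literal parsing is assumed here as a stand-in for the framework's own type coercion rules.

```python
import copy
import json


def deep_merge(base, override):
    """Return a new dict where nested keys in `override` replace those in `base`."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are mappings: merge them key by key.
            merged[key] = deep_merge(merged[key], value)
        else:
            # Scalars, lists, or type mismatches: the override wins.
            merged[key] = copy.deepcopy(value)
    return merged


def parse_override(arg):
    """Turn 'trainer.epochs=20' into {'trainer': {'epochs': 20}}."""
    key, _, raw = arg.partition("=")
    try:
        value = json.loads(raw)  # "20" -> 20, "true" -> True, "0.1" -> 0.1
    except json.JSONDecodeError:
        value = raw  # anything else stays a plain string
    nested = value
    for part in reversed(key.split(".")):
        nested = {part: nested}
    return nested


base = {"trainer": {"epochs": 10, "log_every": 10}}
config = deep_merge(base, parse_override("trainer.epochs=20"))
```

Only `trainer.epochs` changes in `config`; sibling keys such as `trainer.log_every` keep their base values, and `base` itself is left untouched.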
Below is the complete parameter reference organized by section.
Experiment Config
Top-level experiment settings.
| Parameter | Type | Default | Description |
|---|---|---|---|
| name | string | "experiment" | Experiment name (used in W&B run name & checkpoint dirs). |
| seed | int | 42 | Global random seed (Python, NumPy, PyTorch, CUDA). |
| device | string | "cuda" | Device: "cuda", "cpu", or "auto" (auto picks CUDA if available). |
| mixed_precision | string | "fp16" | Mixed precision mode: "fp16", "bf16", or "no". |
| output_dir | string | "./outputs" | Root directory for all experiment outputs. |
| resume_from | string | null | Path to checkpoint to resume from. Auto-loads state dict, optimizer, epoch. |
Data / Dataset Config
Configure dataset loading, transforms, batching, and splitting.
| Parameter | Type | Default | Description |
|---|---|---|---|
| data.dataset | string | — | Dataset name (see Dataset Registry). |
| data.root | string | "./data" | Root directory to cache/download datasets. |
| data.transforms | list | [] | List of transform names or configs (e.g. ["ToTensor", "Normalize"]). |
| data.splits | object | {train: 0.8, val: 0.1, test: 0.1} | Dataset split ratios or named split configs. |
| data.batch_size | int | 64 | Batch size per device (auto-scaled in distributed mode). |
| data.num_workers | int | 4 | DataLoader worker processes. |
| data.shuffle | bool | true | Shuffle training data each epoch. |
| data.pin_memory | bool | true | Pin memory in DataLoader for faster GPU transfer. |
| data.drop_last | bool | false | Drop last incomplete batch. |
| data.persistent_workers | bool | true | Keep worker processes alive between epochs. |
| data.stratify | bool | false | Stratified train/val split by labels. |
| data.cache | bool | true | Cache preprocessed dataset to disk. |
| data.subset | int | null | Use only N samples (useful for debugging). |
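As an illustration of how ratio-based `data.splits` might map to sample counts, the sketch below converts ratios into integer sizes that always sum to the dataset length. The helper name `split_sizes` and the choice to give rounding leftovers to the first split are assumptions, not documented AxiomML behavior.

```python
def split_sizes(n_samples, ratios):
    """Convert split ratios into integer sample counts that sum to n_samples."""
    total = sum(ratios.values())
    if abs(total - 1.0) > 1e-6:
        raise ValueError(f"split ratios must sum to 1.0, got {total}")
    names = list(ratios)
    # The small epsilon guards against float results such as 5999.999999.
    sizes = {name: int(n_samples * ratios[name] + 1e-9) for name in names}
    # Assign any samples lost to rounding to the first split (assumed policy).
    sizes[names[0]] += n_samples - sum(sizes.values())
    return sizes


sizes = split_sizes(60000, {"train": 0.8, "val": 0.1, "test": 0.1})
```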
Model Config
Select model architecture and load pretrained weights.
| Parameter | Type | Default | Description |
|---|---|---|---|
| model.name | string | — | Model architecture name (see Model Zoo). |
| model.params | object | {} | Architecture-specific kwargs (e.g., hidden_dims, num_layers, dropout). |
| model.pretrained | bool or string | false | Load pretrained weights. Set to true for default, or a path/URL. |
| model.freeze | bool | false | Freeze all model parameters (for transfer learning / feature extraction). |
| model.compile | bool | false | Use torch.compile for graph-optimized execution. |
| model.ema_decay | float | null | Exponential Moving Average decay rate. If set, maintains EMA weights. |
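For readers unfamiliar with `model.ema_decay`, the usual exponential moving average rule is `ema = decay * ema + (1 - decay) * weight`. The sketch below applies it to plain Python floats. How often AxiomML updates the EMA weights, and whether it applies any bias correction, is not stated here, so treat this as the textbook rule only; `ema_update` is a hypothetical name.

```python
def ema_update(ema_weights, weights, decay):
    """Apply one EMA step to a flat list of weights."""
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema_weights, weights)]


# With decay=0.9, each step moves the average 10% of the way to the new weights.
ema = ema_update([0.0, 0.0], [1.0, 2.0], decay=0.9)
```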
Optimizer Config
| Parameter | Type | Default | Description |
|---|---|---|---|
| optimizer.name | string | "adam" | Optimizer: "adam", "adamw", "sgd", "rmsprop", "adamax", "nadam". |
| optimizer.lr | float | 0.001 | Learning rate. |
| optimizer.weight_decay | float | 0.0 | Weight decay (L2 penalty). |
| optimizer.momentum | float | 0.9 | Momentum factor (SGD/RMSProp). |
| optimizer.betas | [float,float] | [0.9, 0.999] | Adam betas. |
| optimizer.eps | float | 1e-8 | Numerical stability epsilon. |
| optimizer.grad_clip_norm | float | null | Max gradient norm for clipping (null = no clip). |
| optimizer.grad_clip_value | float | null | Max gradient value for clipping (alternative to norm). |
Scheduler Config
| Parameter | Type | Default | Description |
|---|---|---|---|
| scheduler.name | string | "cosine" | LR scheduler: "cosine", "step", "multistep", "exponential", "plateau", "onecycle", "warmup_cosine", "linear". |
| scheduler.params | object | {} | Scheduler-specific kwargs (e.g., step_size, gamma, milestones, T_max, eta_min, patience). |
| scheduler.warmup_epochs | int | 0 | Linear warmup epochs before scheduler starts. |
| scheduler.warmup_lr | float | 1e-6 | Starting LR during warmup. |
| scheduler.interval | string | "epoch" | Step frequency: "epoch" or "step" (batch). |
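One plausible reading of how `warmup_epochs`, `warmup_lr`, `T_max`, and `eta_min` interact is sketched below: a linear ramp from `warmup_lr` to the base learning rate, followed by cosine annealing down to `eta_min`. The function name and the assumption that `T_max` counts total epochs including warmup are illustrative; the actual `warmup_cosine` scheduler may differ.

```python
import math


def warmup_cosine_lr(epoch, base_lr, warmup_epochs, warmup_lr, t_max, eta_min):
    """Learning rate at the start of `epoch` (0-indexed)."""
    if epoch < warmup_epochs:
        # Linear ramp from warmup_lr up to base_lr.
        frac = epoch / warmup_epochs
        return warmup_lr + frac * (base_lr - warmup_lr)
    # Cosine annealing from base_lr down to eta_min over the remaining epochs.
    span = max(1, t_max - warmup_epochs)
    progress = min((epoch - warmup_epochs) / span, 1.0)
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * progress))


# Values taken from the kind of config shown in this reference.
schedule = [
    warmup_cosine_lr(e, base_lr=3e-4, warmup_epochs=5, warmup_lr=1e-6,
                     t_max=100, eta_min=1e-6)
    for e in range(101)
]
```

Under these assumptions the schedule starts at `warmup_lr`, reaches the base rate at epoch 5, and ends at `eta_min` at epoch 100.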
Trainer Config
| Parameter | Type | Default | Description |
|---|---|---|---|
| trainer.epochs | int | 10 | Number of training epochs. |
| trainer.max_steps | int | null | Max training steps (overrides epochs if set). |
| trainer.batch_size | int | null | Per-device batch size (overrides data.batch_size for training). |
| trainer.grad_accumulation | int | 1 | Gradient accumulation steps (effective batch = batch_size * accum). |
| trainer.validate_every | int | 1 | Validate every N epochs (or N steps if interval="step"). |
| trainer.save_every | int | 1 | Save checkpoint every N epochs. |
| trainer.log_every | int | 10 | Log metrics every N training steps. |
| trainer.early_stop_patience | int | null | Stop training if validation metric does not improve for N epochs. |
| trainer.early_stop_metric | string | "val_loss" | Metric to monitor for early stopping. |
| trainer.early_stop_mode | string | "min" | "min" or "max" — whether lower or higher is better. |
| trainer.verbose | bool | true | Print progress bar and epoch summaries. |
| trainer.profiler | string | null | Profiler: "simple", "advanced", or null. Records timing per component. |
| trainer.detect_anomaly | bool | false | Enable PyTorch autograd anomaly detection (slower, for debugging). |
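The three early-stopping parameters (`early_stop_patience`, `early_stop_metric`, `early_stop_mode`) describe a simple counter. A minimal sketch of that logic follows; the `EarlyStopper` class is hypothetical and only mirrors the behavior described in the table.

```python
class EarlyStopper:
    """Track a validation metric and signal when it stops improving."""

    def __init__(self, patience, mode="min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.mode = mode
        self.best = None
        self.bad_epochs = 0

    def step(self, value):
        """Record one validation result; return True when training should stop."""
        if self.best is None:
            improved = True
        elif self.mode == "min":
            improved = value < self.best
        else:
            improved = value > self.best
        if improved:
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopper(patience=2, mode="min")
decisions = [stopper.step(v) for v in [1.0, 0.8, 0.9, 0.85]]
```

Here the validation loss improves twice and then fails to beat 0.8 for two consecutive epochs, so the fourth call signals a stop.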
Logging Config
| Parameter | Type | Default | Description |
|---|---|---|---|
| logging.project | string | "axioml" | W&B project name. |
| logging.run_name | string | null | W&B run name (auto-generated from config hash if not set). |
| logging.tags | [string] | [] | List of tags for W&B run organization. |
| logging.notes | string | null | Free-text notes for the run. |
| logging.log_images | bool | false | Log sample images/grids to W&B (useful for vision tasks). |
| logging.log_histograms | bool | false | Log weight/gradient histograms each epoch. |
| logging.log_model | bool | true | Upload model checkpoints as W&B artifacts. |
| logging.watch_model | bool | false | Enable W&B model watch (gradients & parameters). |
| logging.watch_freq | int | 100 | Frequency (steps) for model watch logging. |
| logging.console_format | string | "rich" | Console output: "rich" (Rich tables) or "simple" (plain text). |
| logging.quiet | bool | false | Suppress non-essential console output. |
Checkpoint Config
| Parameter | Type | Default | Description |
|---|---|---|---|
| checkpoint.dir | string | "./checkpoints" | Directory to save checkpoints (relative to output_dir). |
| checkpoint.top_k | int | 3 | Keep only top-K checkpoints by metric (cleans older ones). |
| checkpoint.metric | string | "val_loss" | Metric for ranking checkpoints. |
| checkpoint.mode | string | "min" | "min" or "max" — lower/higher metric is better. |
| checkpoint.save_best | bool | true | Save best checkpoint separately as best.pt. |
| checkpoint.save_last | bool | true | Save last checkpoint separately as last.pt. |
| checkpoint.save_optimizer | bool | true | Include optimizer state in checkpoint (required for resume). |
| checkpoint.save_scheduler | bool | true | Include scheduler state in checkpoint. |
| checkpoint.filename_template | string | "epoch={epoch}-{metric}={value:.4f}" | Checkpoint filename pattern. |
| checkpoint.clean_on_start | bool | false | Remove previous checkpoints in dir before starting. |
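The default `filename_template` uses Python format-string syntax, and `top_k` with `mode` amounts to ranking checkpoints by metric. The sketch below shows both ideas under stated assumptions: `format_name` and `keep_top_k` are hypothetical helpers, and no files are touched.

```python
TEMPLATE = "epoch={epoch}-{metric}={value:.4f}"


def format_name(template, epoch, metric, value):
    """Fill in the checkpoint filename pattern."""
    return template.format(epoch=epoch, metric=metric, value=value)


def keep_top_k(checkpoints, k, mode="min"):
    """Split (name, metric_value) pairs into the k best and the rest."""
    ranked = sorted(checkpoints, key=lambda c: c[1], reverse=(mode == "max"))
    return ranked[:k], ranked[k:]


history = [(format_name(TEMPLATE, e, "val_acc", v), v)
           for e, v in enumerate([0.71, 0.91234, 0.88, 0.90])]
kept, removed = keep_top_k(history, k=2, mode="max")
```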
Distributed Config
| Parameter | Type | Default | Description |
|---|---|---|---|
| distributed.enabled | bool | false | Enable distributed training via DDP. |
| distributed.backend | string | "nccl" | Distributed backend: "nccl" (GPU), "gloo" (CPU), "mpi". |
| distributed.world_size | int | null | Number of processes (auto-detected from torchrun if null). |
| distributed.master_addr | string | "localhost" | Master node address for distributed init. |
| distributed.master_port | int | 29500 | Master node port for distributed init. |
| distributed.find_unused_parameters | bool | false | Enable DDP find_unused_parameters (slower, but needed when some parameters receive no gradient in a forward pass; without it DDP can error or hang). |
| distributed.sync_bn | bool | false | Replace BatchNorm with SyncBatchNorm across devices. |
| distributed.gradient_checkpointing | bool | false | Enable gradient checkpointing (trade compute for memory). |
| distributed.mixed_precision | string | null | Override top-level mixed_precision for distributed (defaults to top-level value). |
torchrun --nproc_per_node=4 train.py --config distributed.enabled=true
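When launched this way, torchrun exports `RANK`, `LOCAL_RANK`, `WORLD_SIZE`, `MASTER_ADDR`, and `MASTER_PORT` to each worker process. A sketch of how a framework could auto-detect its distributed settings from them is shown below; `read_torchrun_env` is a hypothetical helper, and the fallback values simply mirror the defaults in the table above.

```python
import os


def read_torchrun_env(environ=None):
    """Read the process-group settings that torchrun exports to each worker."""
    env = os.environ if environ is None else environ
    return {
        "rank": int(env.get("RANK", 0)),
        "local_rank": int(env.get("LOCAL_RANK", 0)),
        "world_size": int(env.get("WORLD_SIZE", 1)),
        "master_addr": env.get("MASTER_ADDR", "localhost"),
        "master_port": int(env.get("MASTER_PORT", 29500)),
    }


# Simulated environment for the second of four workers.
info = read_torchrun_env({"RANK": "1", "LOCAL_RANK": "1", "WORLD_SIZE": "4"})
```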
Full Example Config
# configs/full_example.yaml
# Top-level
name: my_experiment
seed: 42
device: cuda
mixed_precision: fp16
output_dir: ./outputs
# Data
data:
  dataset: cifar10
  root: ./data
  transforms:
    - RandomHorizontalFlip
    - RandomCrop: {size: 32, padding: 4}
    - ToTensor
    - Normalize: {mean: [0.5, 0.5, 0.5], std: [0.5, 0.5, 0.5]}
  batch_size: 128
  num_workers: 8
  shuffle: true
  splits:
    train: 0.8
    val: 0.1
    test: 0.1
# Model
model:
  name: resnet18
  params:
    num_classes: 10
  pretrained: false
  compile: true
# Optimizer
optimizer:
  name: adamw
  lr: 0.0003
  weight_decay: 0.01
  betas: [0.9, 0.95]
# Scheduler
scheduler:
  name: warmup_cosine
  params:
    T_max: 100
    eta_min: 1e-6
  warmup_epochs: 5
# Trainer
trainer:
  epochs: 100
  grad_accumulation: 2
  validate_every: 1
  early_stop_patience: 20
# Logging
logging:
  project: axioml-demo
  tags: [resnet18, cifar10, adamw]
  log_images: true
  log_histograms: true
# Checkpoint
checkpoint:
  dir: ./checkpoints
  top_k: 3
  metric: val_acc
  mode: max
# Distributed
distributed:
  enabled: false
  sync_bn: false
Experiment API
The Experiment class is the main entry point. It accepts a config (path, dict, or parsed config object).
Constructor
Experiment(
    config: str | dict | Config,
    distributed: bool = False,
    resume: str | None = None,
)
Methods
| Method | Returns | Description |
|---|---|---|
| .train() | None | Run the full training loop. Logs to W&B, saves checkpoints. |
| .evaluate(split="test") | dict | Evaluate on given split. Returns {loss, acc, f1, ...}. |
| .predict(x) | Tensor | Run inference on a single input or batch. |
| .export(path, format="torchscript") | None | Export model to TorchScript, ONNX, or a plain state dict. |
| .save_checkpoint(path=None) | str | Save a checkpoint manually. Returns path. |
| .load_checkpoint(path) | None | Load checkpoint (model, optimizer, scheduler state). |
| .log_metrics(metrics, step=None) | None | Log custom metrics to W&B and console. |
| .log_artifact(path, type="model") | None | Log an arbitrary file as a W&B artifact. |
| .summary() | dict | Return run summary with all final metrics. |
| .get_model() | nn.Module | Return the underlying PyTorch model. |
| .get_optimizer() | Optimizer | Return the optimizer instance. |
| .get_scheduler() | Scheduler | Return the LR scheduler instance. |
| .get_train_loader() | DataLoader | Return the training DataLoader. |
| .get_val_loader() | DataLoader | Return the validation DataLoader. |
| .get_test_loader() | DataLoader | Return the test DataLoader. |
Trainer API
The Trainer handles the training loop, validation, checkpointing, and logging. Accessible via exp.trainer.
| Method | Description |
|---|---|
| train() | Execute full training loop with callbacks. |
| train_epoch() | Run a single training epoch. |
| validate() | Run full validation loop. |
| predict(batch) | Forward pass without gradient tracking. |
| get_lr() | Return current learning rate(s). |
Properties
| Property | Type | Description |
|---|---|---|
| current_epoch | int | Current epoch (0-indexed). |
| global_step | int | Total optimizer steps taken. |
| best_metric | float | Best validation metric value so far. |
| best_epoch | int | Epoch at which best metric was achieved. |
| is_distributed | bool | Whether running in DDP mode. |
| device | torch.device | Current device (rank-aware in DDP). |
Data / Dataset API
Built-in Datasets
| Dataset ID | Description | Type |
|---|---|---|
| mnist | MNIST handwritten digits (28x28, 10 classes). | Vision |
| fashion_mnist | Fashion-MNIST clothing items (28x28, 10 classes). | Vision |
| cifar10 | CIFAR-10 tiny images (32x32, 10 classes). | Vision |
| cifar100 | CIFAR-100 tiny images (32x32, 100 classes). | Vision |
| svhn | SVHN street view house numbers (32x32, 10 classes). | Vision |
| imagenet | ImageNet subset or full (configurable root). | Vision |
| text8 | Text8 character-level language modeling. | Text |
| sst2 | SST-2 sentiment classification. | Text |
Custom Datasets
Register any torch.utils.data.Dataset subclass:
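import torch
from axioml.data import register_dataset
@register_dataset("my_dataset")
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, root, split="train", transforms=None):
        # Load or index the samples for the requested split here.
        self.root = root
        self.split = split
        self.transforms = transforms
    def __len__(self):
        # Return the number of samples in this split.
        raise NotImplementedError
    def __getitem__(self, idx):
        # Return the sample at position idx.
        raise NotImplementedError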
Transforms
Transforms are specified by name in config. Built-in transforms include:
| Transform | Description |
|---|---|
| ToTensor | Convert PIL/ndarray to tensor. |
| Normalize | Normalize with mean/std. |
| RandomHorizontalFlip | Random horizontal flip (p=0.5). |
| RandomVerticalFlip | Random vertical flip (p=0.5). |
| RandomCrop | Random crop with optional padding. |
| RandomRotation | Random rotation by degrees. |
| ColorJitter | Random brightness/contrast/saturation/hue. |
| Resize | Resize to target dimensions. |
| CenterCrop | Center crop to target size. |
Models (Model Zoo)
AxiomML includes a registry of built-in model architectures. Custom models can be registered similarly to datasets.
Built-in Models
| Model ID | Description | Params (config.model.params) |
|---|---|---|
| mlp | Multi-layer perceptron. | hidden_dims, dropout, activation |
| lenet | LeNet-5 for small images. | num_classes |
| resnet18 | ResNet-18. | num_classes, pretrained |
| resnet34 | ResNet-34. | num_classes, pretrained |
| resnet50 | ResNet-50. | num_classes, pretrained |
| vit_tiny | Vision Transformer (Tiny). | img_size, patch_size, embed_dim, depth, num_heads, num_classes |
| vit_small | Vision Transformer (Small). | Same as vit_tiny with larger defaults. |
| vit_base | Vision Transformer (Base). | Same as vit_tiny with base defaults. |
Registering Custom Models
import torch.nn as nn
from axioml.models import register_model
@register_model("my_net")
class MyNet(nn.Module):
    def __init__(self, num_classes=10, **kwargs):
        super().__init__()
        # Define layers here.
    def forward(self, x):
        # Placeholder: returns the input unchanged.
        return x
Logging API
AxiomML uses a Logger abstraction that writes to multiple backends simultaneously: W&B, console (Rich), and file.
| Method | Description |
|---|---|
| log_metrics(metrics, step) | Log scalar metrics (loss, acc, lr, etc.). |
| log_image(key, image, step) | Log a single image or tensor. |
| log_images(key, images, step) | Log a grid of images. |
| log_histogram(key, values, step) | Log a histogram of values. |
| log_figure(key, figure, step) | Log a matplotlib figure. |
| log_text(key, text, step) | Log text/table data. |
| log_artifact(path, type) | Log any file as an artifact. |
| log_hyperparameters(config) | Log all config params as W&B hyperparameters. |
| watch(model, log_freq) | Start W&B model watch (gradients & parameters). |
| unwatch() | Stop model watch. |
| finish() | Finalize all loggers (upload, close). |
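Writing to several backends at once is a fan-out pattern: one call is forwarded to every registered backend. The sketch below illustrates it with an in-memory stand-in backend. `ListBackend` and `FanOutLogger` are hypothetical classes, not part of the AxiomML API.

```python
class ListBackend:
    """Stand-in backend that keeps records in memory instead of uploading them."""

    def __init__(self):
        self.records = []

    def log_metrics(self, metrics, step):
        self.records.append((step, dict(metrics)))


class FanOutLogger:
    """Forward every logging call to all configured backends."""

    def __init__(self, backends):
        self.backends = list(backends)

    def log_metrics(self, metrics, step=None):
        for backend in self.backends:
            backend.log_metrics(metrics, step)


console, file_log = ListBackend(), ListBackend()
logger = FanOutLogger([console, file_log])
logger.log_metrics({"loss": 0.5, "lr": 0.001}, step=1)
```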
Callbacks
AxiomML supports a callback system for hooking into the training lifecycle. Callbacks are classes that implement one or more of the following methods:
| Hook | When it fires |
|---|---|
| on_train_start(trainer) | Before the training loop begins. |
| on_train_end(trainer) | After the training loop ends. |
| on_epoch_start(trainer) | Start of each epoch. |
| on_epoch_end(trainer) | End of each epoch. |
| on_batch_start(trainer, batch) | Before each training batch. |
| on_batch_end(trainer, outputs) | After each training batch (loss computed). |
| on_validation_start(trainer) | Before validation loop. |
| on_validation_end(trainer, metrics) | After validation loop with metrics dict. |
| on_checkpoint_save(trainer, path) | After a checkpoint is saved. |
| on_log(trainer, metrics, step) | On each logging step. |
from axioml.callbacks import Callback
class MyCallback(Callback):
    def on_epoch_end(self, trainer):
        print(f"Epoch {trainer.current_epoch} done")
Built-in callbacks (auto-enabled):
- CheckpointCallback — Saves top-K checkpoints by metric.
- LoggingCallback — Metrics logging to W&B and console.
- EarlyStoppingCallback — Stop on plateau.
- ProgressBarCallback — Rich progress bar during training.
- LRMonitorCallback — Log learning rate each step.
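To make the hook lifecycle concrete, here is a toy dispatch loop showing when `on_epoch_start` and `on_epoch_end` would fire. The `Callback` base class below is a local stand-in for `axioml.callbacks.Callback`, and `ToyTrainer` is hypothetical; the real trainer does far more between hooks.

```python
class Callback:
    """Base class with no-op hooks (local stand-in, not the AxiomML class)."""

    def on_epoch_start(self, trainer):
        pass

    def on_epoch_end(self, trainer):
        pass


class EpochRecorder(Callback):
    """Example callback that remembers which epochs finished."""

    def __init__(self):
        self.seen = []

    def on_epoch_end(self, trainer):
        self.seen.append(trainer.current_epoch)


class ToyTrainer:
    def __init__(self, callbacks, epochs):
        self.callbacks = list(callbacks)
        self.epochs = epochs
        self.current_epoch = 0

    def _fire(self, hook):
        # Call the named hook on every callback, in registration order.
        for callback in self.callbacks:
            getattr(callback, hook)(self)

    def train(self):
        for epoch in range(self.epochs):
            self.current_epoch = epoch
            self._fire("on_epoch_start")
            # ... the actual training step would run here ...
            self._fire("on_epoch_end")


recorder = EpochRecorder()
ToyTrainer([recorder], epochs=3).train()
```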
Metrics
AxiomML tracks standard metrics automatically. Custom metrics can be added:
| Metric | Description | When logged |
|---|---|---|
| loss | Training loss (mean over batches). | Each log step |
| val_loss | Validation loss. | Each validation |
| acc / val_acc | Accuracy (classification). | Each validation |
| f1 / val_f1 | F1 score (macro). | Each validation |
| lr | Current learning rate. | Each log step |
| grad_norm | Global gradient norm. | Each log step |
| epoch_time | Time per epoch (seconds). | Each epoch |
| throughput | Samples per second. | Each log step |
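Macro F1, as listed for `f1 / val_f1`, is the unweighted mean of per-class F1 scores. A dependency-free sketch of the standard definition follows; `macro_f1` is a hypothetical helper, and averaging over every label seen in either the targets or the predictions is an assumption about the label set.

```python
def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 over all labels seen in either list."""
    classes = sorted(set(y_true) | set(y_pred))
    scores = []
    for c in classes:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        denom = 2 * tp + fp + fn
        # F1 = 2*TP / (2*TP + FP + FN); defined as 0 when the class never occurs.
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


score = macro_f1([0, 0, 1, 1], [0, 1, 1, 1])
```

In this example class 0 scores 2/3 and class 1 scores 4/5, so the macro average is 11/15.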
Export & Inference
After training, export models for production deployment:
| Method | Format | Description |
|---|---|---|
| exp.export("model.pt") | TorchScript | Script model via torch.jit.script or torch.jit.trace. |
| exp.export("model.onnx", format="onnx") | ONNX | Export via torch.onnx.export with dynamic axes support. |
| exp.export("model.pth", format="state_dict") | State dict | Pure model.state_dict() without architecture wrapper. |
Inference example:
import torch
from axioml import Experiment
exp = Experiment("config.yaml")
exp.load_checkpoint("checkpoints/best.pt")
# Single prediction
x = torch.randn(1, 3, 32, 32) # example input
pred = exp.predict(x)
# Batch prediction
batch = torch.randn(16, 3, 32, 32)
predictions = exp.predict(batch)
# predictions: torch.Tensor of shape (16, num_classes)
Use exp.export("model.onnx", format="onnx") to deploy to ONNX Runtime, TensorRT, or mobile. The exported model is self-contained and does not require AxiomML to run.