Documentation

Complete reference for the AxiomML experiment framework. It covers all configuration parameters, API methods, and advanced usage.

Overview

AxiomML is a modular, config-driven ML experiment framework built on PyTorch. It separates experiment logic from configuration, enabling reproducible runs with zero boilerplate.

Installation

# From PyPI
pip install axioml

# From source (latest)
pip install git+https://github.com/anomalyco/AxiomML.git

Requires Python 3.10+ and PyTorch 2.0+. CUDA is optional but recommended.

Quick Start

Create a config file and run an experiment in three steps:

# 1. Create config.yaml
model: mlp
dataset: mnist
trainer:
  epochs: 5
  batch_size: 64
  learning_rate: 0.001

# 2. Write train.py
from axioml import Experiment

exp = Experiment("config.yaml")
exp.train()
metrics = exp.evaluate()

# 3. Run
python train.py

Configuration System

Configs are YAML (or JSON) files that define every aspect of an experiment.

Below is the complete parameter reference organized by section.

Experiment Config

Top-level experiment settings.

Parameter | Type | Default | Description
name | string | "experiment" | Experiment name (used in W&B run name and checkpoint dirs).
seed | int | 42 | Global random seed (Python, NumPy, PyTorch, CUDA).
device | string | "cuda" | Device: "cuda", "cpu", or "auto" (auto picks CUDA if available).
mixed_precision | string | "fp16" | Mixed precision mode: "fp16", "bf16", or "no".
output_dir | string | "./outputs" | Root directory for all experiment outputs.
resume_from | string | null | Path to checkpoint to resume from. Auto-loads state dict, optimizer, epoch.

Data / Dataset Config

Configure dataset loading, transforms, batching, and splitting.

Parameter | Type | Default | Description
data.dataset | string | - | Dataset name (see Dataset Registry).
data.root | string | "./data" | Root directory to cache/download datasets.
data.transforms | list | [] | List of transform names or configs (e.g. ["ToTensor", "Normalize"]).
data.splits | object | {train: 0.8, val: 0.1, test: 0.1} | Dataset split ratios or named split configs.
data.batch_size | int | 64 | Batch size per device (auto-scaled in distributed mode).
data.num_workers | int | 4 | DataLoader worker processes.
data.shuffle | bool | true | Shuffle training data each epoch.
data.pin_memory | bool | true | Pin memory in DataLoader for faster GPU transfer.
data.drop_last | bool | false | Drop last incomplete batch.
data.persistent_workers | bool | true | Keep worker processes alive between epochs.
data.stratify | bool | false | Stratified train/val split by labels.
data.cache | bool | true | Cache preprocessed dataset to disk.
data.subset | int | null | Use only N samples (useful for debugging).
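
As a rough illustration of how ratio-based data.splits values could translate into sample counts, the sketch below converts ratios into integer sizes. The helper name split_sizes and the rule of giving the rounding remainder to the first split are assumptions made for this example, not documented AxiomML behavior.

```python
def split_sizes(num_samples: int, ratios: dict[str, float]) -> dict[str, int]:
    """Convert split ratios into integer sample counts that sum to num_samples."""
    total = sum(ratios.values())
    if abs(total - 1.0) > 1e-6:
        raise ValueError(f"split ratios must sum to 1.0, got {total}")
    # Floor each share, then hand any leftover samples to the first split (assumed rule).
    sizes = {name: int(num_samples * ratio) for name, ratio in ratios.items()}
    first = next(iter(sizes))
    sizes[first] += num_samples - sum(sizes.values())
    return sizes


sizes = split_sizes(1000, {"train": 0.8, "val": 0.1, "test": 0.1})
print(sizes)
```

Checking that the ratios sum to 1.0 up front catches a common config mistake before any data is loaded.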

Model Config

Select model architecture and load pretrained weights.

Parameter | Type | Default | Description
model.name | string | - | Model architecture name (see Model Zoo).
model.params | object | {} | Architecture-specific kwargs (e.g., hidden_dims, num_layers, dropout).
model.pretrained | bool or string | false | Load pretrained weights. Set to true for default, or a path/URL.
model.freeze | bool | false | Freeze all model parameters (for transfer learning / feature extraction).
model.compile | bool | false | Use torch.compile for graph-optimized execution.
model.ema_decay | float | null | Exponential Moving Average decay rate. If set, maintains EMA weights.
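
To make model.ema_decay more concrete, here is a minimal sketch of the standard exponential moving average update rule. Plain Python lists stand in for parameter tensors, and the function name ema_update is hypothetical; the reference does not state how AxiomML applies the update internally.

```python
def ema_update(ema_weights: list[float], weights: list[float], decay: float) -> list[float]:
    """One EMA step: ema = decay * ema + (1 - decay) * current."""
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema_weights, weights)]


# Three updates toward constant weights [1.0, 2.0], starting from zeros.
ema = [0.0, 0.0]
for step_weights in ([1.0, 2.0], [1.0, 2.0], [1.0, 2.0]):
    ema = ema_update(ema, step_weights, decay=0.5)
print(ema)
```

A decay close to 1.0 makes the averaged weights change slowly, which is why values such as 0.999 are typical in practice.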

Optimizer Config

Parameter | Type | Default | Description
optimizer.name | string | "adam" | Optimizer: "adam", "adamw", "sgd", "rmsprop", "adamax", "nadam".
optimizer.lr | float | 0.001 | Learning rate.
optimizer.weight_decay | float | 0.0 | Weight decay (L2 penalty).
optimizer.momentum | float | 0.9 | Momentum factor (SGD/RMSProp).
optimizer.betas | [float, float] | [0.9, 0.999] | Adam betas.
optimizer.eps | float | 1e-8 | Numerical stability epsilon.
optimizer.grad_clip_norm | float | null | Max gradient norm for clipping (null = no clip).
optimizer.grad_clip_value | float | null | Max gradient value for clipping (alternative to norm).
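
The effect of optimizer.grad_clip_norm can be sketched with the usual global-norm clipping rule: compute the L2 norm over all gradients and scale them down only when that norm exceeds the limit. The function below is an illustrative stand-in that works on a flat list of floats; whether AxiomML delegates to torch.nn.utils.clip_grad_norm_ is an assumption.

```python
import math


def clip_by_global_norm(
    grads: list[float], max_norm: float, eps: float = 1e-6
) -> tuple[list[float], float]:
    """Scale gradients so their global L2 norm is at most max_norm.

    Returns the (possibly scaled) gradients and the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    # The scale factor is capped at 1.0, so small gradients are left untouched.
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm


clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
print(norm_before, clipped)
```

Norm clipping preserves the gradient direction, whereas value clipping (grad_clip_value) clamps each element independently and can change it.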

Scheduler Config

Parameter | Type | Default | Description
scheduler.name | string | "cosine" | LR scheduler: "cosine", "step", "multistep", "exponential", "plateau", "onecycle", "warmup_cosine", "linear".
scheduler.params | object | {} | Scheduler-specific kwargs (e.g., step_size, gamma, milestones, T_max, eta_min, patience).
scheduler.warmup_epochs | int | 0 | Linear warmup epochs before scheduler starts.
scheduler.warmup_lr | float | 1e-6 | Starting LR during warmup.
scheduler.interval | string | "epoch" | Step frequency: "epoch" or "step" (batch).
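
The interaction of warmup_epochs, warmup_lr, T_max, and eta_min can be sketched as a closed-form schedule. This example assumes a linear ramp from warmup_lr to the base learning rate, followed by cosine decay over T_max epochs counted from the end of warmup; the exact convention AxiomML uses for T_max is not specified in the reference, so treat the function as an illustration only.

```python
import math


def warmup_cosine_lr(
    epoch: int,
    base_lr: float,
    warmup_epochs: int,
    warmup_lr: float,
    t_max: int,
    eta_min: float,
) -> float:
    """Learning rate at a given epoch for linear warmup followed by cosine decay."""
    if epoch < warmup_epochs:
        # Linear ramp from warmup_lr (epoch 0) toward base_lr.
        frac = epoch / warmup_epochs
        return warmup_lr + frac * (base_lr - warmup_lr)
    # Cosine decay from base_lr down to eta_min over t_max epochs (assumed convention).
    progress = min(1.0, (epoch - warmup_epochs) / max(1, t_max))
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * progress))


for epoch in (0, 5, 55, 105):
    print(epoch, warmup_cosine_lr(epoch, 1e-3, 5, 1e-6, 100, 1e-6))
```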

Trainer Config

Parameter | Type | Default | Description
trainer.epochs | int | 10 | Number of training epochs.
trainer.max_steps | int | null | Max training steps (overrides epochs if set).
trainer.batch_size | int | null | Per-device batch size (overrides data.batch_size for training).
trainer.grad_accumulation | int | 1 | Gradient accumulation steps (effective batch = batch_size * accum).
trainer.validate_every | int | 1 | Validate every N epochs (or N steps if interval="step").
trainer.save_every | int | 1 | Save checkpoint every N epochs.
trainer.log_every | int | 10 | Log metrics every N training steps.
trainer.early_stop_patience | int | null | Stop training if validation metric does not improve for N epochs.
trainer.early_stop_metric | string | "val_loss" | Metric to monitor for early stopping.
trainer.early_stop_mode | string | "min" | "min" or "max": whether lower or higher is better.
trainer.verbose | bool | true | Print progress bar and epoch summaries.
trainer.profiler | string | null | Profiler: "simple", "advanced", or null. Records timing per component.
trainer.detect_anomaly | bool | false | Enable PyTorch autograd anomaly detection (slower, for debugging).
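
The early_stop_patience and early_stop_mode settings describe a patience counter, which can be sketched as follows. The EarlyStopper class is a hypothetical illustration of that logic under the stated semantics, not the framework's own implementation.

```python
class EarlyStopper:
    """Signal a stop once the metric has not improved for `patience` validations."""

    def __init__(self, patience: int, mode: str = "min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.mode = mode
        self.best = None
        self.bad_epochs = 0

    def step(self, value: float) -> bool:
        """Record one validation result; return True when training should stop."""
        improved = (
            self.best is None
            or (self.mode == "min" and value < self.best)
            or (self.mode == "max" and value > self.best)
        )
        if improved:
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopper(patience=2, mode="min")
history = [0.9, 0.7, 0.8, 0.75, 0.6]  # made-up val_loss values
stopped_at = next((i for i, v in enumerate(history) if stopper.step(v)), None)
print(stopped_at)
```

With patience=2, training stops at the second consecutive validation without improvement, so the later 0.6 is never seen; a larger patience trades extra compute for robustness to noisy metrics.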

Logging Config

Parameter | Type | Default | Description
logging.project | string | "axioml" | W&B project name.
logging.run_name | string | null | W&B run name (auto-generated from config hash if not set).
logging.tags | [string] | [] | List of tags for W&B run organization.
logging.notes | string | null | Free-text notes for the run.
logging.log_images | bool | false | Log sample images/grids to W&B (useful for vision tasks).
logging.log_histograms | bool | false | Log weight/gradient histograms each epoch.
logging.log_model | bool | true | Upload model checkpoints as W&B artifacts.
logging.watch_model | bool | false | Enable W&B model watch (gradients and parameters).
logging.watch_freq | int | 100 | Frequency (steps) for model watch logging.
logging.console_format | string | "rich" | Console output: "rich" (Rich tables) or "simple" (plain text).
logging.quiet | bool | false | Suppress non-essential console output.

Checkpoint Config

Parameter | Type | Default | Description
checkpoint.dir | string | "./checkpoints" | Directory to save checkpoints (relative to output_dir).
checkpoint.top_k | int | 3 | Keep only top-K checkpoints by metric (cleans up older ones).
checkpoint.metric | string | "val_loss" | Metric for ranking checkpoints.
checkpoint.mode | string | "min" | "min" or "max": whether a lower or higher metric is better.
checkpoint.save_best | bool | true | Save best checkpoint separately as best.pt.
checkpoint.save_last | bool | true | Save last checkpoint separately as last.pt.
checkpoint.save_optimizer | bool | true | Include optimizer state in checkpoint (required for resume).
checkpoint.save_scheduler | bool | true | Include scheduler state in checkpoint.
checkpoint.filename_template | string | "epoch={epoch}-{metric}={value:.4f}" | Checkpoint filename pattern.
checkpoint.clean_on_start | bool | false | Remove previous checkpoints in dir before starting.
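
To show how checkpoint.filename_template and checkpoint.top_k might fit together, the sketch below expands the default template with str.format and ranks checkpoints by their metric. Both helper names are hypothetical, and the assumption that the template is expanded with str.format-style fields is inferred from its syntax rather than stated in the reference.

```python
def checkpoint_name(template: str, epoch: int, metric: str, value: float) -> str:
    """Expand a filename template that uses {epoch}, {metric}, and {value} fields."""
    return template.format(epoch=epoch, metric=metric, value=value)


def select_top_k(scores: dict[str, float], k: int, mode: str = "min") -> list[str]:
    """Return the k best checkpoint names, best first, according to mode."""
    ranked = sorted(scores, key=scores.get, reverse=(mode == "max"))
    return ranked[:k]


name = checkpoint_name("epoch={epoch}-{metric}={value:.4f}", 3, "val_loss", 0.25)
print(name)

# Made-up val_loss values per checkpoint; anything outside the top 3 would be deleted.
scores = {"a": 0.5, "b": 0.2, "c": 0.9, "d": 0.3}
print(select_top_k(scores, k=3, mode="min"))
```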

Distributed Config

Parameter | Type | Default | Description
distributed.enabled | bool | false | Enable distributed training via DDP.
distributed.backend | string | "nccl" | Distributed backend: "nccl" (GPU), "gloo" (CPU), "mpi".
distributed.world_size | int | null | Number of processes (auto-detected from torchrun if null).
distributed.master_addr | string | "localhost" | Master node address for distributed init.
distributed.master_port | int | 29500 | Master node port for distributed init.
distributed.find_unused_parameters | bool | false | Enable DDP find_unused_parameters (slower but avoids hangs).
distributed.sync_bn | bool | false | Replace BatchNorm with SyncBatchNorm across devices.
distributed.gradient_checkpointing | bool | false | Enable gradient checkpointing (trade compute for memory).
distributed.mixed_precision | string | null | Override top-level mixed_precision for distributed (defaults to top-level value).

Launch with torchrun:

torchrun --nproc_per_node=4 train.py --config distributed.enabled=true
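
The launch command above passes a dotted key=value override (distributed.enabled=true). One plausible way such an override could be merged into a nested config dict is sketched below. The function names are hypothetical, and parsing values as JSON scalars is an assumption; the reference does not describe how AxiomML parses overrides.

```python
import json


def parse_value(raw: str):
    """Interpret a CLI string as a JSON scalar when possible, else keep it as text."""
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        return raw


def apply_override(config: dict, override: str) -> dict:
    """Apply one 'a.b.c=value' override to a nested dict in place and return it."""
    dotted_key, _, raw = override.partition("=")
    *parents, leaf = dotted_key.split(".")
    node = config
    for key in parents:
        # Create intermediate sections on demand.
        node = node.setdefault(key, {})
    node[leaf] = parse_value(raw)
    return config


cfg = {"distributed": {"enabled": False, "backend": "nccl"}}
apply_override(cfg, "distributed.enabled=true")
apply_override(cfg, "distributed.backend=gloo")
apply_override(cfg, "trainer.epochs=5")
print(cfg)
```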

Full Example Config

# configs/full_example.yaml
# Top-level
name: my_experiment
seed: 42
device: cuda
mixed_precision: fp16
output_dir: ./outputs

# Data
data:
  dataset: cifar10
  root: ./data
  transforms:
    - RandomHorizontalFlip
    - RandomCrop: {size: 32, padding: 4}
    - ToTensor
    - Normalize: {mean: [0.5, 0.5, 0.5], std: [0.5, 0.5, 0.5]}
  batch_size: 128
  num_workers: 8
  shuffle: true
  splits:
    train: 0.8
    val: 0.1
    test: 0.1

# Model
model:
  name: resnet18
  params:
    num_classes: 10
    pretrained: false
  compile: true

# Optimizer
optimizer:
  name: adamw
  lr: 0.0003
  weight_decay: 0.01
  betas: [0.9, 0.95]

# Scheduler
scheduler:
  name: warmup_cosine
  params:
    T_max: 100
    eta_min: 1e-6
  warmup_epochs: 5

# Trainer
trainer:
  epochs: 100
  grad_accumulation: 2
  validate_every: 1
  early_stop_patience: 20

# Logging
logging:
  project: axioml-demo
  tags: [resnet18, cifar10, adamw]
  log_images: true
  log_histograms: true

# Checkpoint
checkpoint:
  dir: ./checkpoints
  top_k: 3
  metric: val_acc
  mode: max

# Distributed
distributed:
  enabled: false
  sync_bn: false

Experiment API

The Experiment class is the main entry point. It accepts a config (path, dict, or parsed config object).

Constructor

Experiment(
    config: str | dict | Config,
    distributed: bool = False,
    resume: str | None = None,
)

Methods

Method | Returns | Description
.train() | None | Run the full training loop. Logs to W&B, saves checkpoints.
.evaluate(split="test") | dict | Evaluate on given split. Returns {loss, acc, f1, ...}.
.predict(x) | Tensor | Run inference on a single input or batch.
.export(path, format="torchscript") | None | Export model to TorchScript or ONNX.
.save_checkpoint(path=None) | str | Save a checkpoint manually. Returns path.
.load_checkpoint(path) | None | Load checkpoint (model, optimizer, scheduler state).
.log_metrics(metrics, step=None) | None | Log custom metrics to W&B and console.
.log_artifact(path, type="model") | None | Log an arbitrary file as a W&B artifact.
.summary() | dict | Return run summary with all final metrics.
.get_model() | nn.Module | Return the underlying PyTorch model.
.get_optimizer() | Optimizer | Return the optimizer instance.
.get_scheduler() | Scheduler | Return the LR scheduler instance.
.get_train_loader() | DataLoader | Return the training DataLoader.
.get_val_loader() | DataLoader | Return the validation DataLoader.
.get_test_loader() | DataLoader | Return the test DataLoader.

Trainer API

The Trainer handles the training loop, validation, checkpointing, and logging. Accessible via exp.trainer.

Method | Description
train() | Execute full training loop with callbacks.
train_epoch() | Run a single training epoch.
validate() | Run full validation loop.
predict(batch) | Forward pass without gradient tracking.
get_lr() | Return current learning rate(s).

Properties

Property | Type | Description
current_epoch | int | Current epoch (0-indexed).
global_step | int | Total optimizer steps taken.
best_metric | float | Best validation metric value so far.
best_epoch | int | Epoch at which best metric was achieved.
is_distributed | bool | Whether running in DDP mode.
device | torch.device | Current device (rank-aware in DDP).

Data / Dataset API

Built-in Datasets

Dataset ID | Description | Type
mnist | MNIST handwritten digits (28x28, 10 classes). | Vision
fashion_mnist | Fashion-MNIST clothing items (28x28, 10 classes). | Vision
cifar10 | CIFAR-10 tiny images (32x32, 10 classes). | Vision
cifar100 | CIFAR-100 tiny images (32x32, 100 classes). | Vision
svhn | SVHN street view house numbers (32x32, 10 classes). | Vision
imagenet | ImageNet subset or full (configurable root). | Vision
text8 | Text8 character-level language modeling. | Text
sst2 | SST-2 sentiment classification. | Text

Custom Datasets

Register any torch.utils.data.Dataset subclass:

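import torch
from axioml.data import register_dataset

@register_dataset("my_dataset")
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, root, split="train", transforms=None):
        # ...
        pass

    def __len__(self):
        # Return the number of samples.
        raise NotImplementedError

    def __getitem__(self, idx):
        # Return the sample at position idx.
        raise NotImplementedError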
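
For readers curious how a name-based registry of this kind can work, here is a minimal, self-contained sketch of a registration decorator and lookup. The names DATASET_REGISTRY, register, and build_dataset are invented for this example and are not AxiomML APIs; the real register_dataset may behave differently.

```python
from typing import Callable

# Maps a config name such as "toy" to the class that implements it.
DATASET_REGISTRY: dict[str, type] = {}


def register(name: str) -> Callable[[type], type]:
    """Class decorator that stores the class under `name` and returns it unchanged."""
    def decorator(cls: type) -> type:
        if name in DATASET_REGISTRY:
            raise KeyError(f"dataset '{name}' is already registered")
        DATASET_REGISTRY[name] = cls
        return cls
    return decorator


@register("toy")
class ToyDataset:
    def __init__(self, size: int = 4):
        self.items = list(range(size))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> int:
        return self.items[idx]


def build_dataset(name: str, **kwargs):
    """Look up a registered class by name and instantiate it with config kwargs."""
    return DATASET_REGISTRY[name](**kwargs)


dataset = build_dataset("toy", size=3)
print(len(dataset), dataset[1])
```

Because the decorator runs at import time, the module defining a custom class has to be imported before a config can refer to it by name.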

Transforms

Transforms are specified by name in config. Built-in transforms include:

Transform | Description
ToTensor | Convert PIL/ndarray to tensor.
Normalize | Normalize with mean/std.
RandomHorizontalFlip | Random horizontal flip (p=0.5).
RandomVerticalFlip | Random vertical flip (p=0.5).
RandomCrop | Random crop with optional padding.
RandomRotation | Random rotation by degrees.
ColorJitter | Random brightness/contrast/saturation/hue.
Resize | Resize to target dimensions.
CenterCrop | Center crop to target size.
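
Transform entries in a config appear in two shapes: a bare name ("ToTensor") or a single-key mapping with arguments (RandomCrop: {size: 32, padding: 4}). The sketch below normalizes both shapes into a (name, kwargs) pair. The function parse_transform_spec is hypothetical and only illustrates the parsing step, not how AxiomML instantiates transforms.

```python
def parse_transform_spec(spec) -> tuple[str, dict]:
    """Normalize a transform entry into (name, kwargs)."""
    if isinstance(spec, str):
        # Bare name: no arguments.
        return spec, {}
    if isinstance(spec, dict) and len(spec) == 1:
        # Single-key mapping: the key is the name, the value holds the kwargs.
        name, kwargs = next(iter(spec.items()))
        return name, dict(kwargs or {})
    raise ValueError(f"unsupported transform spec: {spec!r}")


specs = [
    "RandomHorizontalFlip",
    {"RandomCrop": {"size": 32, "padding": 4}},
    "ToTensor",
]
parsed = [parse_transform_spec(s) for s in specs]
print(parsed)
```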

Models (Model Zoo)

AxiomML includes a registry of built-in model architectures. Custom models can be registered similarly to datasets.

Built-in Models

Model ID | Description | Params (config.model.params)
mlp | Multi-layer perceptron. | hidden_dims, dropout, activation
lenet | LeNet-5 for small images. | num_classes
resnet18 | ResNet-18. | num_classes, pretrained
resnet34 | ResNet-34. | num_classes, pretrained
resnet50 | ResNet-50. | num_classes, pretrained
vit_tiny | Vision Transformer (Tiny). | img_size, patch_size, embed_dim, depth, num_heads, num_classes
vit_small | Vision Transformer (Small). | Same as vit_tiny with larger defaults.
vit_base | Vision Transformer (Base). | Same as vit_tiny with base defaults.

Registering Custom Models

from torch import nn
from axioml.models import register_model

@register_model("my_net")
class MyNet(nn.Module):
    def __init__(self, num_classes=10, **kwargs):
        super().__init__()
        # Define layers here.
        # ...

    def forward(self, x):
        # Placeholder: replace with the real forward pass.
        return x

Logging API

AxiomML uses a Logger abstraction that writes to multiple backends simultaneously: W&B, console (Rich), and file.

Method | Description
log_metrics(metrics, step) | Log scalar metrics (loss, acc, lr, etc.).
log_image(key, image, step) | Log a single image or tensor.
log_images(key, images, step) | Log a grid of images.
log_histogram(key, values, step) | Log a histogram of values.
log_figure(key, figure, step) | Log a matplotlib figure.
log_text(key, text, step) | Log text/table data.
log_artifact(path, type) | Log any file as an artifact.
log_hyperparameters(config) | Log all config params as W&B hyperparameters.
watch(model, log_freq) | Start W&B model watch (gradients and parameters).
unwatch() | Stop model watch.
finish() | Finalize all loggers (upload, close).

Callbacks

AxiomML supports a callback system for hooking into the training lifecycle. Callbacks are classes that implement one or more of the following methods:

Hook | When it fires
on_train_start(trainer) | Before the training loop begins.
on_train_end(trainer) | After the training loop ends.
on_epoch_start(trainer) | Start of each epoch.
on_epoch_end(trainer) | End of each epoch.
on_batch_start(trainer, batch) | Before each training batch.
on_batch_end(trainer, outputs) | After each training batch (loss computed).
on_validation_start(trainer) | Before validation loop.
on_validation_end(trainer, metrics) | After validation loop with metrics dict.
on_checkpoint_save(trainer, path) | After a checkpoint is saved.
on_log(trainer, metrics, step) | On each logging step.

from axioml.callbacks import Callback

class MyCallback(Callback):
    def on_epoch_end(self, trainer):
        print(f"Epoch {trainer.current_epoch} done")
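
The following toy loop shows how hooks like these are typically dispatched. The Callback base class and ToyTrainer here are local stand-ins written for this sketch, not the classes from axioml; in particular, how callbacks are attached to the real trainer is not covered by this reference.

```python
class Callback:
    """Stand-in base class with no-op hooks; subclasses override what they need."""

    def on_train_start(self, trainer): pass
    def on_epoch_end(self, trainer): pass
    def on_train_end(self, trainer): pass


class EpochRecorder(Callback):
    """Example callback that remembers every epoch index it was notified about."""

    def __init__(self):
        self.seen = []

    def on_epoch_end(self, trainer):
        self.seen.append(trainer.current_epoch)


class ToyTrainer:
    def __init__(self, epochs: int, callbacks: list[Callback]):
        self.epochs = epochs
        self.callbacks = callbacks
        self.current_epoch = 0

    def _fire(self, hook: str, *args):
        # Call the named hook on every callback, in registration order.
        for cb in self.callbacks:
            getattr(cb, hook)(self, *args)

    def train(self):
        self._fire("on_train_start")
        for epoch in range(self.epochs):
            self.current_epoch = epoch
            # A real loop would run batches and validation here.
            self._fire("on_epoch_end")
        self._fire("on_train_end")


recorder = EpochRecorder()
ToyTrainer(epochs=3, callbacks=[recorder]).train()
print(recorder.seen)
```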

Built-in callbacks are enabled automatically.

Metrics

AxiomML tracks the following standard metrics automatically. Custom metrics can be logged with exp.log_metrics().

Metric | Description | When logged
loss | Training loss (mean over batches). | Each log step
val_loss | Validation loss. | Each validation
acc / val_acc | Accuracy (classification). | Each validation
f1 / val_f1 | F1 score (macro). | Each validation
lr | Current learning rate. | Each log step
grad_norm | Global gradient norm. | Each log step
epoch_time | Time per epoch (seconds). | Each epoch
throughput | Samples per second. | Each log step
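
Since f1 and val_f1 are macro-averaged, it may help to see what that means computationally: F1 is computed per class and then averaged with equal weight per class. The function below is a plain-Python sketch of that definition for integer labels, not the framework's implementation.

```python
def macro_f1(y_true: list[int], y_pred: list[int]) -> float:
    """Unweighted mean of per-class F1 over all classes seen in labels or predictions."""
    classes = sorted(set(y_true) | set(y_pred))
    scores = []
    for c in classes:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        # F1 = 2*TP / (2*TP + FP + FN); defined as 0.0 when the class never appears.
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


score = macro_f1([0, 0, 1, 1], [0, 1, 1, 1])
print(round(score, 4))
```

Macro averaging gives rare classes the same weight as common ones, so it can differ noticeably from accuracy on imbalanced datasets.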

Export & Inference

After training, export models for production deployment:

Method | Format | Description
exp.export("model.pt") | TorchScript | Script model via torch.jit.script or torch.jit.trace.
exp.export("model.onnx", format="onnx") | ONNX | Export via torch.onnx.export with dynamic axes support.
exp.export("model.pth", format="state_dict") | State dict | Pure model.state_dict() without architecture wrapper.

Inference example:

import torch
from axioml import Experiment

exp = Experiment("config.yaml")
exp.load_checkpoint("checkpoints/best.pt")

# Single prediction
x = torch.randn(1, 3, 32, 32)  # example input
pred = exp.predict(x)

# Batch prediction
batch = torch.randn(16, 3, 32, 32)
predictions = exp.predict(batch)
# predictions: torch.Tensor of shape (16, num_classes)

Tip: Use exp.export("model.onnx", format="onnx") for deployment to ONNX Runtime, TensorRT, or mobile. The exported model is self-contained and does not require AxiomML to run.