"""
Core class for micromind. Supports helper function for exports. Out-of-the-box
multi-gpu and FP16 training with HF Accelerate and much more.
Authors:
- Francesco Paissan, 2023
"""
from abc import ABC, abstractmethod
from argparse import Namespace
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from accelerate import Accelerator
from tqdm import tqdm
import warnings
from .utils.helpers import get_logger
logger = get_logger()
# This is used ONLY if you are not using argparse to get the hparams
# (see MicroMind.__init__: wrapped in an argparse.Namespace when hparams=None).
default_cfg = {
    "output_folder": "results",  # root folder for experiment outputs
    "experiment_name": "micromind_exp",  # subfolder name, also used by export()
    "opt": "adam",  # this is ignored if you are overriding the configure_optimizers
    "lr": 0.001,  # this is ignored if you are overriding the configure_optimizers
    "debug": False,  # debug mode: shortened epochs/batches in train()
}
@dataclass
class Stage:
    """Enum-like dataclass tracking the training stage.

    The values are plain ints so they can be used directly as dictionary
    keys (see ``Metric.history``).
    """

    train: int = 0
    val: int = 1
    test: int = 2
class Metric:
    """
    Class for tracking evaluation metrics during training.

    This class allows you to create custom evaluation metrics by providing a
    function to compute the metric and specifying a reduction method.

    Arguments
    ---------
    name : str
        The name of the metric.
    fn : Callable
        A function that computes the metric given predictions and batch data.
    reduction : Optional[str]
        The reduction method for the metric ('sum' or 'mean'). Default is 'mean'.
    eval_only : Optional[bool]
        If True, the metric is tracked only during validation/test. Default False.
    eval_period : Optional[int]
        Compute the metric only on epochs where
        ``(current_epoch + 1) % eval_period == 0``. Default is 1.

    Returns
    -------
    Reduced metric. Optionally, you can access the metric history
    before call reduce(clear=True) : torch.Tensor

    Example
    -------
    .. doctest::

        >>> from micromind import Metric, Stage
        >>> import torch
        >>> def custom_metric(pred, batch):
        ...     # Replace this with your custom metric calculation
        ...     return pred - batch
        >>> metric = Metric("Custom Metric", custom_metric, reduction="mean")
        >>> pred = torch.tensor([1.0, 2.0, 3.0])
        >>> batch = torch.tensor([0.5, 1.5, 2.5])
        >>> metric(pred, batch, stage=Stage.train)
        >>> metric.history
        {0: [tensor([0.5000, 0.5000, 0.5000])], 1: [], 2: []}
        >>> metric.reduce(Stage.train)
        0.5
    """

    def __init__(
        self,
        name: str,
        fn: Callable,
        reduction: Optional[str] = "mean",
        eval_only: Optional[bool] = False,
        eval_period: Optional[int] = 1,
    ):
        self.name = name
        self.fn = fn
        self.reduction = reduction
        self.eval_only = eval_only
        self.eval_period = eval_period
        # One list of raw metric tensors per stage, consumed by reduce().
        self.history = {s: [] for s in [Stage.train, Stage.val, Stage.test]}

    def __call__(self, pred, batch, stage, device="cpu"):
        """
        Compute the metric for a given prediction and batch data and append
        the result to the history of the given stage.

        Arguments
        ---------
        pred : torch.Tensor
            The model's prediction.
        batch : torch.Tensor
            The ground truth or target values.
        stage : Stage
            The current stage (e.g., Stage.train).
        device : Optional[str]
            The device on which to perform the computation. Default is 'cpu'.
            NOTE(review): currently unused in this method — confirm intent.
        """
        dat = self.fn(pred, batch)
        if dat.ndim == 0:
            # Store scalars as 1-element tensors so torch.cat works in reduce().
            dat = dat.unsqueeze(0)

        self.history[stage].append(dat)

    def reduce(self, stage, clear=False):
        """
        Reduce the accumulated metric history of a stage to a single float.

        Arguments
        ---------
        stage : Stage
            The stage whose history should be reduced (e.g., Stage.train).
        clear : bool
            If True, the history for the stage is emptied after reduction.

        Returns
        -------
        The reduced ('mean' or 'sum') metric value. : float

        Raises
        ------
        ValueError
            If the configured reduction is neither 'mean' nor 'sum'
            (previously this surfaced as an opaque NameError).
        """
        if self.reduction == "mean":
            tmp = torch.cat(self.history[stage], dim=0).mean()
        elif self.reduction == "sum":
            tmp = torch.cat(self.history[stage], dim=0).sum()
        else:
            raise ValueError(
                f"Invalid reduction: {self.reduction}. Should be 'mean' or 'sum'."
            )

        if clear:
            self.history[stage] = []

        return tmp.item()
class MicroMind(ABC):
    """
    MicroMind is an abstract base class for creating and training deep learning
    models. Handles training on multi-gpu via accelerate (using DDP and other
    distributed training strategies). It automatically handles the device
    management for the training and the micromind's export capabilities to onnx,
    OpenVino and TFLite.

    Arguments
    ---------
    hparams : Optional[Namespace]
        Hyperparameters for the model. Default is None.
    """

    def __init__(self, hparams=None):
        # Fall back to the library-wide defaults when no hparams are passed.
        if hparams is None:
            hparams = Namespace(**default_cfg)

        self.modules = torch.nn.ModuleDict({})  # init empty modules dict
        self.hparams = hparams
        self.input_shape = None  # set later via set_input_shape / export()

        # Accelerate owns device placement, DDP and mixed-precision handling.
        self.accelerator = Accelerator()
        self.device = self.accelerator.device

        self.current_epoch = 0
[docs]
@abstractmethod
def forward(self, batch):
"""
Forward step of the class. It gets called during inference and optimization.
This method should be overwritten for specific applications.
Arguments
---------
batch : torch.Tensor
Batch as output from the defined DataLoader.
Returns
-------
pred : Union[torch.Tensor, Tuple]
Predictions - this depends on the task.
"""
pass
[docs]
@abstractmethod
def compute_loss(self, pred, batch):
"""
Computes the cost function for the optimization process. It return a
tensor on which backward() is called. This method should be overwritten
for the specific application.
Arguments
---------
pred : Union[torch.Tensor, Tuple]
Output of the forward() function
batch : torch.Tensor
Batch as defined from the DataLoader.
Returns
-------
loss : torch.Tensor
Compute cost function.
"""
pass
[docs]
def load_modules(self, checkpoint_path: Union[Path, str]):
"""Loads models for path.
Arguments
---------
checkpoint_path : Union[Path, str]
Path to the checkpoint where the modules are stored.
"""
dat = torch.load(checkpoint_path)
modules_keys = list(self.modules.keys())
for k in self.modules:
self.modules[k].load_state_dict(dat[k])
modules_keys.remove(k)
if len(modules_keys) != 0:
logger.info(f"Couldn't find a state_dict for modules {modules_keys}.")
[docs]
def export(
self, save_dir: Union[Path, str], out_format: str = "onnx", input_shape=None
) -> None:
"""
Export the model to a specified format for deployment.
TFLite and OpenVINO need a Linux machine to be exported.
Arguments
---------
save_dir : Union[Path, str]
The directory where the exported model will be saved.
out_format : Optional[str]
The format for exporting the model. Default is 'onnx'.
input_shape : Optional[Tuple]
The input shape of the model. If not provided, the input shape
specified during model creation is used.
"""
from micromind import convert
if not isinstance(save_dir, Path):
save_dir = Path(save_dir)
save_dir = save_dir.joinpath(self.hparams.experiment_name)
self.set_input_shape(input_shape)
assert (
self.input_shape is not None
), "You should pass the input_shape of the model."
if out_format == "onnx":
convert.convert_to_onnx(
self, save_dir.joinpath("model.onnx"), replace_forward=True
)
elif out_format == "openvino":
convert.convert_to_openvino(self, save_dir, replace_forward=True)
elif out_format == "tflite":
convert.convert_to_tflite(self, save_dir, replace_forward=True)
def __call__(self, *x, **xv):
"""Just forwards everything to the forward method."""
return self.forward(*x, **xv)
[docs]
def on_train_start(self):
"""Initializes the optimizer, modules and puts the networks on the right
devices. Optionally loads checkpoint if already present.
This function gets executed at the beginning of every training.
"""
# pass debug status to checkpointer
self.checkpointer.debug = self.hparams.debug
init_opt = self.configure_optimizers()
if isinstance(init_opt, list) or isinstance(init_opt, tuple):
self.opt, self.lr_sched = init_opt
else:
self.opt = init_opt
self.init_devices()
self.start_epoch = 0
if self.checkpointer is not None:
# recover state
ckpt = self.checkpointer.recover_state()
if ckpt is not None:
accelerate_path, self.start_epoch = ckpt
self.accelerator.load_state(accelerate_path)
else:
tmp = """
You are not passing a checkpointer to the training function, \
thus no status will be saved. If this is not the intended behaviour \
please check https://micromind-toolkit.github.io/docs/").
"""
warnings.warn(" ".join(tmp.split()))
[docs]
    def init_devices(self):
        """Initializes the data pipeline and modules for DDP and accelerated inference.
        To control the device selection, use `accelerate config`."""
        # Collect everything accelerate should wrap, in a fixed order:
        # [modules, (opt), (lr_sched), *datasets] — the index arithmetic
        # below relies on this ordering.
        convert = [self.modules]
        if hasattr(self, "opt"):
            convert += [self.opt]
        if hasattr(self, "lr_sched"):
            convert += [self.lr_sched]
        if hasattr(self, "datasets"):
            # if the datasets are stored here, prepare them for DDP
            convert += list(self.datasets.values())

        # NOTE(review): Accelerator.prepare takes *args; passing the whole
        # list as a single argument may return it unmodified — confirm
        # whether this should be `self.accelerator.prepare(*convert)`.
        accelerated = self.accelerator.prepare(convert)
        self.modules = accelerated[0]
        self.accelerator.register_for_checkpointing(self.modules)
        if hasattr(self, "opt"):
            self.opt = accelerated[1]
            self.accelerator.register_for_checkpointing(self.opt)
        if hasattr(self, "lr_sched"):
            # NOTE(review): index 2 assumes `opt` exists whenever `lr_sched`
            # does; this holds when configure_optimizers returns (opt, sched).
            self.lr_sched = accelerated[2]
            self.accelerator.register_for_checkpointing(self.lr_sched)
        if hasattr(self, "datasets"):
            # Datasets occupy the tail of `accelerated`; walk keys in reverse
            # so the -(i + 1) indices line up with insertion order.
            for i, key in enumerate(list(self.datasets.keys())[::-1]):
                self.datasets[key] = accelerated[-(i + 1)]

        self.modules.to(self.device)
[docs]
def on_train_end(self):
"""Runs at the end of each training. Cleans up before exiting."""
pass
[docs]
def eval(self):
self.modules.eval()
[docs]
def train(
self,
epochs: int = 1,
datasets: Dict = {},
metrics: List[Metric] = [],
checkpointer=None, # fix type hints
debug: bool = False,
) -> None:
"""
This method trains the model on the provided training dataset for the
specified number of epochs. It tracks training metrics and can
optionally perform validation during training, if the validation set is
provided.
Arguments
---------
epochs : int
The number of training epochs.
datasets : Dict
A dictionary of dataset loaders. Dataloader should be mapped to keys
"train", "val", and "test".
metrics : Optional[List[Metric]]
A list of metrics to track during training. Default is an empty list.
debug : bool
Whether to run in debug mode. Default is False. If in debug mode,
only runs for few epochs
and with few batches.
"""
self.datasets = datasets
self.metrics = metrics
self.checkpointer = checkpointer
assert "train" in self.datasets, "Training dataloader was not specified."
assert epochs > 0, "You must specify at least one epoch."
self.debug = debug
self.on_train_start()
if self.accelerator.is_local_main_process:
logger.info(
f"Starting from epoch {self.start_epoch + 1}."
+ f" Training is scheduled for {epochs} epochs."
)
with self.accelerator.autocast():
for e in range(self.start_epoch + 1, epochs + 1):
self.current_epoch = e
pbar = tqdm(
self.datasets["train"],
unit="batches",
ascii=True,
dynamic_ncols=True,
disable=not self.accelerator.is_local_main_process,
)
loss_epoch = 0
pbar.set_description(f"Running epoch {self.current_epoch}/{epochs}")
self.modules.train()
for idx, batch in enumerate(pbar):
if isinstance(batch, list):
batch = [b.to(self.device) for b in batch]
self.opt.zero_grad()
model_out = self(batch)
loss = self.compute_loss(model_out, batch)
loss_epoch += loss.item()
self.accelerator.backward(loss)
self.opt.step()
if hasattr(self, "lr_sched"):
# ok for cos_lr
self.lr_sched.step()
for m in self.metrics:
if (
self.current_epoch + 1
) % m.eval_period == 0 and not m.eval_only:
m(model_out, batch, Stage.train, self.device)
running_train = {}
for m in self.metrics:
if (
self.current_epoch + 1
) % m.eval_period == 0 and not m.eval_only:
running_train["train_" + m.name] = m.reduce(Stage.train)
running_train.update({"train_loss": loss_epoch / (idx + 1)})
pbar.set_postfix(**running_train)
if self.debug and idx > 10:
break
pbar.close()
train_metrics = {}
for m in self.metrics:
if (
self.current_epoch + 1
) % m.eval_period == 0 and not m.eval_only:
train_metrics["train_" + m.name] = m.reduce(Stage.train, True)
train_metrics.update({"train_loss": loss_epoch / (idx + 1)})
if "val" in datasets:
val_metrics = self.validate()
if (
self.accelerator.is_local_main_process
and self.checkpointer is not None
):
self.checkpointer(
self,
train_metrics,
val_metrics,
)
else:
val_metrics = train_metrics.update(
{"val_loss": loss_epoch / (idx + 1)}
)
if e >= 1 and self.debug:
break
self.on_train_end()
return None
[docs]
@torch.no_grad()
def validate(self) -> Dict:
"""Runs the validation step."""
assert "val" in self.datasets, "Validation dataloader was not specified."
self.modules.eval()
pbar = tqdm(
self.datasets["val"],
unit="batches",
ascii=True,
dynamic_ncols=True,
disable=not self.accelerator.is_local_main_process,
)
loss_epoch = 0
pbar.set_description("Validation...")
with self.accelerator.autocast():
for idx, batch in enumerate(pbar):
if isinstance(batch, list):
batch = [b.to(self.device) for b in batch]
self.opt.zero_grad()
model_out = self(batch)
loss = self.compute_loss(model_out, batch)
for m in self.metrics:
if (self.current_epoch + 1) % m.eval_period == 0:
m(model_out, batch, Stage.val, self.device)
loss_epoch += loss.item()
pbar.set_postfix(loss=loss_epoch / (idx + 1))
if self.debug and idx > 10:
break
val_metrics = {}
for m in self.metrics:
if (self.current_epoch + 1) % m.eval_period == 0:
val_metrics["val_" + m.name] = m.reduce(Stage.val, True)
val_metrics.update({"val_loss": loss_epoch / (idx + 1)})
pbar.close()
return val_metrics
[docs]
@torch.no_grad()
def test(self, datasets: Dict = {}, metrics: List[Metric] = []) -> None:
"""Runs the test steps.
Arguments
---------
datasets : Dict
Dictionary with the test DataLoader. Should be present in the key
`test`.
metrics : List[Metric]
List of metrics to compute during test step.
Returns
-------
Metrics computed on test set. : Dict[torch.Tensor]
"""
assert "test" in datasets, "Test dataloader was not specified."
self.modules.eval()
pbar = tqdm(
datasets["test"],
unit="batches",
ascii=True,
dynamic_ncols=True,
disable=not self.accelerator.is_local_main_process,
)
loss_epoch = 0
pbar.set_description("Testing...")
with self.accelerator.autocast():
for idx, batch in enumerate(pbar):
if isinstance(batch, list):
batch = [b.to(self.device) for b in batch]
model_out = self(batch)
loss = self.compute_loss(model_out, batch)
for m in metrics:
m(model_out, batch, Stage.test, self.device)
loss_epoch += loss.item()
pbar.set_postfix(loss=loss_epoch / (idx + 1))
pbar.close()
test_metrics = {"test_" + m.name: m.reduce(Stage.test, True) for m in metrics}
test_metrics.update({"test_loss": loss_epoch / (idx + 1)})
s_out = (
"Testing "
+ " - ".join([f"{k}: {v:.2f}" for k, v in test_metrics.items()])
+ "; "
)
logger.info(s_out)
return test_metrics