PyTorch Lightning
- It is a framework that simplifies the code needed to train, evaluate, and test models (see the minimal sketch after this list)
- It manages logging to TensorBoard
- It manages saving model checkpoints with minimal code overhead
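To get a feel for how little boilerplate this involves, here is a minimal, self-contained sketch of the pattern the rest of this section builds on. The LitRegressor module and the random dataset are invented purely for illustration:

import torch
import torch.nn as nn
import torch.utils.data as data
import pytorch_lightning as pl

class LitRegressor(pl.LightningModule):
    # A tiny LightningModule: the model, the loss, and the optimizer live in one class
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.mse_loss(self.net(x), y)
        self.log("train_loss", loss)  # logged automatically (to TensorBoard by default)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Random data, just so the sketch runs end to end
dataset = data.TensorDataset(torch.randn(256, 8), torch.randn(256, 1))
loader = data.DataLoader(dataset, batch_size=32)

# The Trainer handles the training loop, device placement, logging, and checkpointing
trainer = pl.Trainer(max_epochs=2)
trainer.fit(LitRegressor(), loader)

By default, logs and checkpoints are written under a lightning_logs/ directory in the working directory.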
Import the Library
import os

import torch
import torch.nn as nn
import torch.optim as optim

import pytorch_lightning as pl

# setting the seed
pl.seed_everything(42)

Global seed set to 42
42
PyTorch Lightning code organization
### Example of a Lightning module
class CIFARModule(pl.LightningModule):

    def __init__(self, model_name, model_hparams, optimizer_name, optimizer_hparams):
        """
        Inputs:
            model_name - Name of the model/CNN to run. Used for creating the model (see function below)
            model_hparams - Hyperparameters for the model, as dictionary.
            optimizer_name - Name of the optimizer to use. Currently supported: Adam, SGD
            optimizer_hparams - Hyperparameters for the optimizer, as dictionary. This includes learning rate, weight decay, etc.
        """
        super().__init__()
        # Exports the hyperparameters to a YAML file, and creates the 'self.hparams' namespace
        self.save_hyperparameters()
        # Create model
        self.model = create_model(model_name, model_hparams)
        # Create loss module
        self.loss_module = nn.CrossEntropyLoss()
        # Example input for visualizing the graph in TensorBoard
        self.example_input_array = torch.zeros((1, 3, 32, 32), dtype=torch.float32)

    def forward(self, imgs):
        return self.model(imgs)

    def configure_optimizers(self):
        if self.hparams.optimizer_name == 'Adam':
            optimizer = optim.AdamW(self.parameters(), **self.hparams.optimizer_hparams)
        elif self.hparams.optimizer_name == 'SGD':
            optimizer = optim.SGD(self.parameters(), **self.hparams.optimizer_hparams)
        else:
            assert False, f"Unknown optimizer: \"{self.hparams.optimizer_name}\""

        # Reduce the learning rate by 0.1 after 100 and 150 epochs
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self.model(imgs)
        loss = self.loss_module(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()

        # Log accuracy per epoch to TensorBoard (weighted average over batches)
        self.log('train_acc', acc, on_step=False, on_epoch=True)
        self.log('train_loss', loss)
        return loss  # Return tensor to call .backward() on

    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self.model(imgs)
        loss = self.loss_module(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()
        # By default, metrics are logged per epoch (weighted average over batches)
        self.log('val_acc', acc)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self.model(imgs).argmax(dim=-1)
        acc = (labels == preds).float().mean()
        self.log('test_acc', acc)
# callbacks
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
model_dict = {}

def create_model(model_name, model_hparams):
    if model_name in model_dict:
        return model_dict[model_name](**model_hparams)
    else:
        assert False, f"Unknown model name \"{model_name}\". Available models are: {str(model_dict.keys())}"

act_fn_by_name = {
    "tanh": nn.Tanh,
    "relu": nn.ReLU,
    "leakyrelu": nn.LeakyReLU,
    "gelu": nn.GELU
}
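To make the model registry concrete, here is a hypothetical example of how a model could be registered in model_dict and wrapped in CIFARModule. The SimpleCNN class and its hyperparameters are made up purely for illustration; any real model class would be registered the same way:

class SimpleCNN(nn.Module):
    # Hypothetical tiny CNN for 32x32 RGB images, e.g. CIFAR
    def __init__(self, num_classes=10, act_fn_name="relu"):
        super().__init__()
        act_fn = act_fn_by_name[act_fn_name]
        self.net = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1), act_fn(),
            nn.MaxPool2d(2),  # 32x32 -> 16x16
            nn.Conv2d(16, 32, kernel_size=3, padding=1), act_fn(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(32, num_classes)
        )

    def forward(self, x):
        return self.net(x)

model_dict["SimpleCNN"] = SimpleCNN

module = CIFARModule(model_name="SimpleCNN",
                     model_hparams={"num_classes": 10, "act_fn_name": "relu"},
                     optimizer_name="SGD",
                     optimizer_hparams={"lr": 0.1, "momentum": 0.9, "weight_decay": 1e-4})
print(module.hparams)  # save_hyperparameters() makes all __init__ arguments available here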
def train_model(model_name, save_name=None, **kwargs):
    """
    Inputs:
        model_name - Name of the model you want to run. Is used to look up the class in "model_dict"
        save_name (optional) - If specified, this name will be used for creating the checkpoint and logging directory.
    """
    if save_name is None:
        save_name = model_name

    # Create a PyTorch Lightning trainer with the generation callback
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, save_name),  # Where to save models
                         accelerator="gpu" if str(device).startswith("cuda") else "cpu",  # We run on a GPU (if possible)
                         devices=1,  # How many GPUs/CPUs we want to use (1 is enough for the notebooks)
                         max_epochs=180,  # How many epochs to train for if no patience is set
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),  # Save the best checkpoint based on the maximum val_acc recorded. Saves only weights and not optimizer
                                    LearningRateMonitor("epoch")],  # Log learning rate every epoch
                         enable_progress_bar=True)  # Set to False if you do not want a progress bar
    trainer.logger._log_graph = True  # If True, we plot the computation graph in TensorBoard
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    # Check whether a pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, save_name + ".ckpt")
    if os.path.isfile(pretrained_filename):
        print(f"Found pretrained model at {pretrained_filename}, loading...")
        model = CIFARModule.load_from_checkpoint(pretrained_filename)  # Automatically loads the model with the saved hyperparameters
    else:
        pl.seed_everything(42)  # To be reproducible
        model = CIFARModule(model_name=model_name, **kwargs)
        trainer.fit(model, train_loader, val_loader)
        model = CIFARModule.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)  # Load best checkpoint after training

    # Test best model on validation and test set
    val_result = trainer.test(model, val_loader, verbose=False)
    test_result = trainer.test(model, test_loader, verbose=False)
    # trainer.test logs the metric as "test_acc" even when run on the validation loader
    result = {"test": test_result[0]["test_acc"], "val": val_result[0]["test_acc"]}

    return model, result
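A usage sketch, hedged: it assumes CHECKPOINT_PATH, device, and the train_loader/val_loader/test_loader objects are defined earlier in the notebook, and that a model, here the hypothetical SimpleCNN from above, is registered in model_dict:

simplecnn_model, simplecnn_results = train_model(model_name="SimpleCNN",
                                                 model_hparams={"num_classes": 10, "act_fn_name": "relu"},
                                                 optimizer_name="SGD",
                                                 optimizer_hparams={"lr": 0.1, "momentum": 0.9, "weight_decay": 1e-4})
print(f"Validation accuracy: {simplecnn_results['val']:.3f}")
print(f"Test accuracy: {simplecnn_results['test']:.3f}")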