PyTorch Lightning
- It is a framework that simplifies the code needed to train, evaluate, and test models
- It manages logging to TensorBoard
- It manages saving model checkpoints with minimal code overhead

Import the Library

import os
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl

# setting the seed
pl.seed_everything(42)
Global seed set to 42
42
PyTorch Lightning code organization
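The typical organization is: a LightningModule bundles the network, the loss, the optimizer configuration, and the training/validation/test steps, while the Trainer runs the loop and takes care of logging and checkpointing. Below is a minimal sketch of that split; the module name TinyModule, the layer sizes, and the commented-out train_loader are illustrative placeholders, not part of this tutorial.

import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl

class TinyModule(pl.LightningModule):                # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(3 * 32 * 32, 10)        # the network lives inside the module
        self.loss_module = nn.CrossEntropyLoss()     # so does the loss

    def forward(self, imgs):                         # inference behaviour
        return self.net(imgs.flatten(1))

    def training_step(self, batch, batch_idx):       # one batch of the training loop
        imgs, labels = batch
        loss = self.loss_module(self(imgs), labels)
        self.log('train_loss', loss)                 # Lightning routes this to TensorBoard
        return loss

    def configure_optimizers(self):                  # optimizer (and optionally a scheduler)
        return optim.SGD(self.parameters(), lr=0.1)

# The Trainer owns the loop, logging, and checkpointing:
# trainer = pl.Trainer(max_epochs=1)
# trainer.fit(TinyModule(), train_loader)            # train_loader: any PyTorch DataLoader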

### Example of a Lightning module

class CIFARModule(pl.LightningModule):

    def __init__(self, model_name, model_hparams, optimizer_name, optimizer_hparams):
        """
        Inputs:
            model_name - Name of the model/CNN to run. Used for creating the model (see function below)
            model_hparams - Hyperparameters for the model, as dictionary.
            optimizer_name - Name of the optimizer to use. Currently supported: Adam, SGD
            optimizer_hparams - Hyperparameters for the optimizer, as dictionary. This includes learning rate, weight decay, etc.
        """
        super().__init__()
        # Exports the hyperparameters to a YAML file and creates the 'self.hparams' namespace
        self.save_hyperparameters()
        # Create model
        self.model = create_model(model_name, model_hparams)
        # Create loss module
        self.loss_module = nn.CrossEntropyLoss()
        # Example input for visualizing the graph in TensorBoard
        self.example_input_array = torch.zeros((1, 3, 32, 32), dtype=torch.float32)

    def forward(self, imgs):
        return self.model(imgs)

    def configure_optimizers(self):
        if self.hparams.optimizer_name == 'Adam':
            # AdamW is Adam with a correct implementation of weight decay
            optimizer = optim.AdamW(
                self.parameters(), **self.hparams.optimizer_hparams)
        elif self.hparams.optimizer_name == 'SGD':
            optimizer = optim.SGD(self.parameters(), **self.hparams.optimizer_hparams)
        else:
            assert False, f"Unknown optimizer: \"{self.hparams.optimizer_name}\""
        # Reduce the learning rate by 0.1 after 100 and 150 epochs
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[100, 150], gamma=0.1)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self.model(imgs)
        loss = self.loss_module(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()
        # Log accuracy per epoch to TensorBoard
        self.log('train_acc', acc, on_step=False, on_epoch=True)
        self.log('train_loss', loss)
        return loss  # Return tensor to call .backward()

    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self.model(imgs)
        loss = self.loss_module(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()  # Compare predicted classes, not raw logits, with the labels
        self.log('val_acc', acc)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self.model(imgs).argmax(dim=-1)
        acc = (labels == preds).float().mean()
        self.log('test_acc', acc)
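For reference, instantiating the module could look as follows once create_model and model_dict (both introduced right below) are in place. The model name "SimpleCNN" and the hyperparameter values are placeholders, not part of this tutorial.

# Hypothetical instantiation -- "SimpleCNN" and the hyperparameter values are placeholders
module = CIFARModule(model_name="SimpleCNN",
                     model_hparams={"num_classes": 10},
                     optimizer_name="SGD",
                     optimizer_hparams={"lr": 0.1, "momentum": 0.9, "weight_decay": 1e-4})
print(module.hparams.optimizer_name)  # save_hyperparameters() exposes the arguments as self.hparams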
# callbacks
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

model_dict = {}
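model_dict starts out empty; each CNN class gets registered in it under a string name, which is how create_model (below) and the CIFARModule constructor look models up. A sketch of the registration pattern with a hypothetical SimpleCNN; the tutorial's actual CNN classes are defined elsewhere.

# Hypothetical model class, only to illustrate the registration pattern
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10, act_fn=nn.ReLU):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1), act_fn(),
            nn.Flatten(),
            nn.Linear(16 * 32 * 32, num_classes))

    def forward(self, x):
        return self.net(x)

model_dict["SimpleCNN"] = SimpleCNN  # registered under the name used by create_model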
def create_model(model_name, model_hparams):
    if model_name in model_dict:
        return model_dict[model_name](**model_hparams)
    else:
        assert False, f"Unknown model name \"{model_name}\". Available models are: {str(model_dict.keys())}"

act_fn_by_name = {
    "tanh": nn.Tanh,
    "relu": nn.ReLU,
    "leakyrelu": nn.LeakyReLU,
    "gelu": nn.GELU
}

def train_model(model_name, save_name=None, **kwargs):
    """
    Inputs:
        model_name - Name of the model you want to run. Is used to look up the class in "model_dict"
        save_name (optional) - If specified, this name will be used for creating the checkpoint and logging directory.
    """
    if save_name is None:
        save_name = model_name

    # Create a PyTorch Lightning trainer with the generation callback
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, save_name),  # Where to save models
                         accelerator="gpu" if str(device).startswith("cuda") else "cpu",  # We run on a GPU (if possible)
                         devices=1,  # How many GPUs/CPUs we want to use (1 is enough for the notebooks)
                         max_epochs=180,  # How many epochs to train for if no patience is set
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),  # Save the best checkpoint based on the maximum val_acc recorded. Saves only weights and not optimizer
                                    LearningRateMonitor("epoch")],  # Log learning rate every epoch
                         enable_progress_bar=True)  # Set to False if you do not want a progress bar
    trainer.logger._log_graph = True  # If True, we plot the computation graph in TensorBoard
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    # Check whether a pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, save_name + ".ckpt")
    if os.path.isfile(pretrained_filename):
        print(f"Found pretrained model at {pretrained_filename}, loading...")
        model = CIFARModule.load_from_checkpoint(pretrained_filename)  # Automatically loads the model with the saved hyperparameters
    else:
        pl.seed_everything(42)  # To be reproducible
        model = CIFARModule(model_name=model_name, **kwargs)
        trainer.fit(model, train_loader, val_loader)
        model = CIFARModule.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)  # Load best checkpoint after training

    # Test best model on validation and test set
    val_result = trainer.test(model, val_loader, verbose=False)
    test_result = trainer.test(model, test_loader, verbose=False)
    result = {"test": test_result[0]["test_acc"], "val": val_result[0]["test_acc"]}  # trainer.test always reports the metric logged in test_step ("test_acc")
    return model, result
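A hedged usage sketch of train_model: it assumes that CHECKPOINT_PATH, device, train_loader, val_loader, and test_loader have been defined in earlier cells and that the requested model is registered in model_dict; the model name and hyperparameter values below are placeholders.

# Hypothetical call, reusing the placeholder SimpleCNN registered above
simplecnn_model, simplecnn_results = train_model(model_name="SimpleCNN",
                                                 model_hparams={"num_classes": 10},
                                                 optimizer_name="SGD",
                                                 optimizer_hparams={"lr": 0.1, "momentum": 0.9, "weight_decay": 1e-4})
print("Validation accuracy:", simplecnn_results["val"])
print("Test accuracy:", simplecnn_results["test"])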