Higher-level PyTorch APIs: A short introduction to PyTorch Lightning

Setting up the PyTorch Lightning model

import pytorch_lightning as pl 
import torch
import torch.nn as nn
from torchmetrics import Accuracy


class MultiLayerPerceptron(pl.LightningModule):
    """Fully connected image classifier packaged as a Lightning module.

    The network flattens each input, passes it through a stack of
    ``Linear`` + ``ReLU`` hidden layers and ends in a ``Linear`` layer that
    outputs one unnormalized score (logit) per class.

    Args:
        image_shape: Shape of a single input sample, e.g. ``(1, 28, 28)``;
            only the product of its dimensions (the flattened size) is used.
        hidden_units: Width of each hidden layer, in order. May be empty, in
            which case the model is a single linear layer.
        num_classes: Number of output classes.
        learning_rate: Learning rate used by the Adam optimizer.
    """

    def __init__(self, image_shape=(1, 28, 28), hidden_units=(32, 16),
                 num_classes=10, learning_rate=0.001):
        super().__init__()
        self.learning_rate = learning_rate

        # One accuracy metric per stage so their accumulated states never mix.
        # Current torchmetrics versions require the ``task`` argument.
        self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.valid_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.test_acc = Accuracy(task="multiclass", num_classes=num_classes)

        # Flattened input size: product of all dimensions of one sample.
        input_size = 1
        for dim in image_shape:
            input_size *= dim

        all_layers = [nn.Flatten()]
        for hidden_unit in hidden_units:
            all_layers.append(nn.Linear(input_size, hidden_unit))
            all_layers.append(nn.ReLU())
            input_size = hidden_unit
        # ``input_size`` is now the width of the last hidden layer (or the
        # flattened input when there are no hidden layers). No Softmax here:
        # ``cross_entropy`` applies log-softmax itself and expects raw logits.
        all_layers.append(nn.Linear(input_size, num_classes))
        self.model = nn.Sequential(*all_layers)

    def forward(self, x):
        """Return raw class logits (one column per class) for a batch ``x``."""
        return self.model(x)

    def _shared_step(self, batch):
        """Run one batch through the model; return (loss, predictions, targets)."""
        x, y = batch
        logits = self(x)  # single forward pass, reused for loss and predictions
        loss = nn.functional.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        return loss, preds, y

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss; accumulate training accuracy."""
        loss, preds, y = self._shared_step(batch)
        self.train_acc.update(preds, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def training_epoch_end(self, outs):
        """Log the accuracy over the finished epoch and clear its state."""
        self.log("train_acc", self.train_acc.compute())
        # Without a reset the metric would keep accumulating across epochs.
        self.train_acc.reset()

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss; accumulate validation accuracy."""
        loss, preds, y = self._shared_step(batch)
        self.valid_acc.update(preds, y)
        self.log("valid_loss", loss, prog_bar=True)
        # Log the metric object rather than a per-step ``compute()``, so that
        # Lightning computes it once over the whole epoch and then resets it.
        self.log("valid_acc", self.valid_acc,
                 on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss; accumulate test accuracy."""
        loss, preds, y = self._shared_step(batch)
        self.test_acc.update(preds, y)
        self.log("test_loss", loss, prog_bar=True)
        # Same epoch-level logging as in ``validation_step``.
        self.log("test_acc", self.test_acc,
                 on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Use Adam over all model parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

Setting up the data loaders for Lightning

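There are three main ways to provide data to a Lightning model:

  • Make the dataset part of the model
  • Set up the data loaders as usual and feed them to the fit method of a Lightning Trainer — the Trainer is introduced in the next subsection
  • Create a LightningDataModule

The code below uses the third approach, a LightningDataModule.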
from torch.utils.data import DataLoader 
from torch.utils.data import random_split
from torchvision.datasets import MNIST
from torchvision import transforms
class MnistDataModule(pl.LightningDataModule):
    """LightningDataModule that provides MNIST train/validation/test loaders.

    The official MNIST training split is divided into 55,000 training and
    5,000 validation samples using a fixed random seed, so the split is
    reproducible. The official test split is used for testing.

    Args:
        data_path: Directory where MNIST is downloaded to and read from.
        batch_size: Number of samples per batch for every loader.
        num_workers: Number of worker processes per DataLoader.
    """

    def __init__(self, data_path='./', batch_size=64, num_workers=4):
        super().__init__()
        self.data_path = data_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transforms.Compose([transforms.ToTensor()])

    def prepare_data(self):
        """Download both MNIST splits so ``setup`` can load them offline."""
        MNIST(root=self.data_path, train=True, download=True)
        MNIST(root=self.data_path, train=False, download=True)

    def setup(self, stage=None):
        """Build the train/validation/test datasets.

        ``stage`` is one of 'fit', 'validate', 'test' or 'predict'. It is not
        used here: all splits are always created.
        """
        mnist_all = MNIST(
            root=self.data_path,
            train=True,
            transform=self.transform,
            download=False,
        )
        # Fixed generator seed -> identical train/val split on every run.
        self.train, self.val = random_split(
            mnist_all, [55000, 5000], generator=torch.Generator().manual_seed(1)
        )
        self.test = MNIST(
            root=self.data_path,
            train=False,
            transform=self.transform,
            download=False,
        )

    def train_dataloader(self):
        """Return the training loader, reshuffled every epoch."""
        return DataLoader(self.train, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return DataLoader(self.val, batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def test_dataloader(self):
        """Return the test loader (fixed order)."""
        return DataLoader(self.test, batch_size=self.batch_size,
                          num_workers=self.num_workers)


torch.manual_seed(1)
mnist_dm = MnistDataModule()

Training the model using the PyTorch Lightning Trainer class

mnistclassifier = MultiLayerPerceptron()

# Train on one GPU when available, otherwise on the CPU.
# ``accelerator``/``devices`` replace the ``gpus`` argument, which was
# deprecated in Lightning 1.7.
if torch.cuda.is_available():
    trainer = pl.Trainer(max_epochs=10, accelerator="gpu", devices=1)
else:
    trainer = pl.Trainer(max_epochs=10)

trainer.fit(model=mnistclassifier, datamodule=mnist_dm)

# To inspect the logged metrics with TensorBoard, run from a terminal:
#     tensorboard --logdir lightning_logs/
# or, inside a Jupyter notebook:
#     %load_ext tensorboard
#     %tensorboard --logdir lightning_logs/

--

--

Get the Medium app

A button that says 'Download on the App Store', and if clicked it will lead you to the iOS App store
A button that says 'Get it on, Google Play', and if clicked it will lead you to the Google Play store
ODSC - Open Data Science

ODSC - Open Data Science

93K Followers

Our passion is bringing thousands of the best and brightest data scientists together under one roof for an incredible learning and networking experience.