pytorch-lightning 是建立在pytorch之上的高层次模型接口。
pytorch-lightning 之于 pytorch,就如同keras之于 tensorflow.
通过使用 pytorch-lightning,用户无需编写自定义训练循环就可以非常简洁地在CPU、单GPU、多GPU、乃至多TPU上训练模型。
无需考虑模型和数据在cpu,cuda之间的移动,并且可以通过回调函数实现CheckPoint参数保存,实现断点续训功能。
一般按照如下方式 安装和 引入 pytorch-lightning 库。
#安装
pip install pytorch-lightning
#引入
import pytorch_lightning as pl
顾名思义,它可以帮助我们漂亮(pl)地进行深度学习研究。😋😋
pytorch-lightning 的核心设计哲学是将 深度学习项目中的 研究代码(定义模型) 和 工程代码 (训练模型) 相互分离。
用户只需专注于研究代码(pl.LightningModule)的实现,而工程代码借助训练工具类(pl.Trainer)统一实现。
更详细地说,深度学习项目代码可以分成如下4部分:
- 研究代码 (Research code),用户继承LightningModule实现。
- 工程代码 (Engineering code),用户无需关注通过调用Trainer实现。
- 非必要代码 (Non-essential research code,logging, etc...),用户通过调用Callbacks实现。
- 数据 (Data),用户通过torch.utils.data.DataLoader实现。
下面我们使用minist图片分类问题为例,演示pytorch-lightning的最佳实践。
1,准备数据
import torch
from torch import nn
import torchvision
from torchvision import transforms
transform = transforms.Compose([transforms.ToTensor()])
ds_train = torchvision.datasets.MNIST(root="./minist/",train=True,download=True,transform=transform)
ds_valid = torchvision.datasets.MNIST(root="./minist/",train=False,download=True,transform=transform)
dl_train = torch.utils.data.DataLoader(ds_train, batch_size=128, shuffle=True, num_workers=4)
dl_valid = torch.utils.data.DataLoader(ds_valid, batch_size=128, shuffle=False, num_workers=4)
print(len(ds_train))
print(len(ds_valid))
Done!
60000
10000
2,定义模型
import pytorch_lightning as pl
import datetime
class Model(pl.LightningModule):
def __init__(self):
super().__init__()
self.layers = nn.ModuleList([
nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3),
nn.MaxPool2d(kernel_size = 2,stride = 2),
nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),
nn.MaxPool2d(kernel_size = 2,stride = 2),
nn.Dropout2d(p = 0.1),
nn.AdaptiveMaxPool2d((1,1)),
nn.Flatten(),
nn.Linear(64,32),
nn.ReLU(),
nn.Linear(32,10)]
)
def forward(self,x):
for layer in self.layers:
x = layer(x)
return x
#定义loss,以及可选的各种metrics
def training_step(self, batch, batch_idx):
x, y = batch
prediction = self(x)
loss = nn.CrossEntropyLoss()(prediction,y)
return loss
#定义optimizer,以及可选的lr_scheduler
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return {"optimizer":optimizer}
def validation_step(self, batch, batch_idx):
loss = self.training_step(batch,batch_idx)
return {"val_loss":loss}
def test_step(self, batch, batch_idx):
loss = self.training_step(batch,batch_idx)
return {"test_loss":loss}
3,训练模型
pl.seed_everything(1234)
model = Model()
ckpt_callback = pl.callbacks.ModelCheckpoint(
monitor='val_loss',
save_top_k=1,
mode='min'
)
# gpus=0 则使用cpu训练,gpus=1则使用1个gpu训练,gpus=2则使用2个gpu训练,gpus=-1则使用所有gpu训练,
# gpus=[0,1]则指定使用0号和1号gpu训练, gpus="0,1,2,3"则使用0,1,2,3号gpu训练
# tpus=1 则使用1个tpu训练
trainer = pl.Trainer(max_epochs=5,gpus=0,callbacks = [ckpt_callback])
#断点续训
#trainer = pl.Trainer(resume_from_checkpoint='./lightning_logs/version_31/checkpoints/epoch=02-val_loss=0.05.ckpt')
trainer.fit(model,dl_train,dl_valid)
Global seed set to 1234
GPU available: False, used: False
TPU available: None, using: 0 TPU cores
| Name | Type | Params
--------------------------------------
0 | layers | ModuleList | 54.0 K
--------------------------------------
54.0 K Trainable params
0 Non-trainable params
54.0 K Total params
Epoch 4: 100% >>>>>>>>>>>>>>>>>>>>>>>>>>>> 158/158 [00:19<00:00, 8.08it/s, loss=0.138, v_num=34]
4,评估模型
result = trainer.test(model, test_dataloaders=dl_valid)
print(result)
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': tensor(0.0047)}
--------------------------------------------------------------------------------
[{'test_loss': 0.004680501762777567}]
5,使用模型
data,label = next(iter(dl_valid))
model.eval()
prediction = model(data)
print(prediction)
tensor([[ -5.1149, -6.1142, 2.0591, ..., 7.0609, -5.4144, 0.5222],
[ -2.2989, -5.6076, 3.7343, ..., -1.8391, -6.4941, -3.4076],
[ 0.9215, 6.9357, -1.9887, ..., -2.2996, -0.8034, -3.2993],
...,
[ -4.5674, -6.0223, -0.9309, ..., -3.5468, 0.3367, 4.5473],
[ 4.3023, -4.1629, -1.2742, ..., -4.2527, -2.3449, -2.5585],
[ -3.8913, -10.3790, -1.7804, ..., -4.6757, -0.7428, 1.0305]],
grad_fn=<AddmmBackward>)
6,保存模型
最优模型默认保存在 trainer.checkpoint_callback.best_model_path 的目录下,可以直接加载。
print(trainer.checkpoint_callback.best_model_path)
print(trainer.checkpoint_callback.best_model_score)
/Users/liangyun/CodeFiles/PythonAiRoad/lightning_logs/version_34/checkpoints/epoch=04-val_loss=0.00.ckpt
tensor(0.0047)
model_clone = Model.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
trainer_clone = pl.Trainer(max_epochs=3)
result = trainer_clone.test(model_clone,dl_valid)
print(result)
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': tensor(0.0047)}
--------------------------------------------------------------------------------
[{'test_loss': 0.004680501762777567}]
如果对本文内容理解上有需要进一步和作者交流的地方,欢迎在公众号"算法美食屋"下留言。作者时间和精力有限,会酌情予以回复。
也可以在公众号后台回复关键字:加群,加入读者交流群和大家讨论。