Pytorch lightning学习

基本操作

使用LightningModule

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class LitAutoEncoder(L.LightningModule):
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder

def training_step(self, batch, batch_idx):
# training_step defines the train loop.
x, _ = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
loss = F.mse_loss(x_hat, x)
return loss

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer

对设定的encoder和decoder自动执行下面的training_step,应包含batch导入,forward步骤,loss导入与计算。结束后自动使用optimizer。

1
2
3
4
5
6
# model
autoencoder = LitAutoEncoder(Encoder(), Decoder())

# train model
trainer = L.Trainer()
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

定义好模型后使用trainer.fit进行训练,实际上后台执行的逻辑是:

1
2
3
4
5
6
7
8
9
autoencoder = LitAutoEncoder(Encoder(), Decoder())
optimizer = autoencoder.configure_optimizers()

for batch_idx, batch in enumerate(train_loader):
loss = autoencoder.training_step(batch, batch_idx)

loss.backward()
optimizer.step()
optimizer.zero_grad()

添加验证集及测试集

把一个数据集按预定比例拆分,分别作为训练集验证集。

1
2
3
4
5
6
7
# use 20% of training data for validation
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size

# split the train set into two
seed = torch.Generator().manual_seed(42)
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)

在原模型中添加你需要的验证步骤,注意将验证集传入.fit使用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class LitAutoEncoder(L.LightningModule):
def training_step(self, batch, batch_idx):
...

def validation_step(self, batch, batch_idx):
# this is the validation loop
x, _ = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
val_loss = F.mse_loss(x_hat, x)
self.log("val_loss", val_loss)

trainer = L.Trainer()
trainer.fit(model, train_loader, valid_loader)

测试集也一样,添加test_step步骤后,在.test中调用。

1
trainer.test(model, dataloaders=DataLoader(test_set))

保存/加载模型

Lightning 会自动在您当前工作目录中为您保存一个检查点,其中包含您上一个训练 epoch 的状态。这可确保在训练中断时可以恢复训练。更改路径可修改Trainer中的default_root_dir参数。通过以下方法恢复保存的参数:

1
model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")

早停

监控指标,在未观察到改善时停止训练。

1
2
3
4
5
6
7
8
9
10
11
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

class LitModel(LightningModule):
def validation_step(self, batch, batch_idx):
loss = ...
self.log("val_loss", loss)


model = LitModel()
trainer = Trainer(callbacks=[EarlyStopping(monitor="val_loss", mode="min")])
trainer.fit(model)

您可以通过更改回调的参数来自定义回调行为。

1
2
early_stop_callback = EarlyStopping(monitor="val_accuracy", min_delta=0.00, patience=3, verbose=False, mode="max")
trainer = Trainer(callbacks=[early_stop_callback])

停止训练的其他参数:

  • stopping_threshold:一旦监测的数量达到此阈值,立即停止训练。 当我们知道超过某个最佳值不会进一步使我们受益时,它就很有用。
  • divergence_threshold:一旦监控的数量低于此阈值,就立即停止训练。 当达到如此糟糕的值时,我们认为模型无法再恢复,最好尽早停止并在不同的初始条件下运行。
  • check_finite:开启后,如果监控的指标变为 NaN 或 infinite,它会停止训练。
  • check_on_train_epoch_end:开启后,它会在训练 epoch 结束时检查指标。

预训练

对已经PATH训练过的模型,可以采取这种方式进行微调与新数据集的预测。

1
2
3
4
5
6
7
8
model = ImagenetTransferLearning()
trainer = Trainer()
trainer.fit(model)
model = ImagenetTransferLearning.load_from_checkpoint(PATH)
model.freeze()

x = some_images_from_cifar10()
predictions = model(x)

调试模型

设置断点:在此示例中,代码将在执行y = x**2之前停止。

1
2
3
4
5
6
def function_to_debug():
x = 2

# set breakpoint
breakpoint()
y = x**2

快速测试:此代码会运行5批训练、验证、测试数据,你也可以自定义数字。

1
trainer = Trainer(fast_dev_run=True)

缩短epoch:也可以用具体数字而非百分比。

1
trainer = Trainer(limit_train_batches=0.1, limit_val_batches=0.01)

发现训练瓶颈

advance暂时无法使用,simple可以打印每一步的用时,但是有点粗略可能没啥用。

1
profiler="simple"

可以检测下GPU/TPU/HPU容量,自行判断。

1
2
3
from lightning.pytorch.callbacks import DeviceStatsMonitor

trainer = Trainer(callbacks=[DeviceStatsMonitor()])

跟踪指标

在训练过程记录value指标:

1
2
3
def training_step(self, batch, batch_idx):
value = ...
self.log("some_value", value)

如果要记录多个指标,需要使用self.log_dict:

1
2
values = {"loss": loss, "acc": acc, "metric_n": metric_n}  # add more items if needed
self.log_dict(values, prog_bar=True)

在命令行使用以下命令可以观察累计指标:

1
tensorboard --logdir=lightning_logs/

image-20250717202418931


Pytorch lightning学习
http://pleinelune-r.github.io/2025/08/05/Pytorch lightning学习/
作者
Pleinelune
发布于
2025年8月5日
许可协议