基本操作 使用LightningModule 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 class LitAutoEncoder (L.LightningModule ): def __init__ (self , encoder, decoder ): super ().__init__() self .encoder = encoder self .decoder = decoder def training_step (self , batch, batch_idx ): x, _ = batch x = x.view(x.size(0 ), -1 ) z = self .encoder(x) x_hat = self .decoder(z) loss = F.mse_loss(x_hat, x) return loss def configure_optimizers (self ): optimizer = torch.optim.Adam (self .parameters(), lr=1e-3 ) return optimizer
对设定的encoder和decoder自动执行下面的training_step,应包含batch导入,forward步骤,loss导入与计算。结束后自动使用optimizer。
1 2 3 4 5 6 autoencoder = LitAutoEncoder(Encoder(), Decoder()) trainer = L.Trainer() trainer.fit(model =autoencoder, train_dataloaders =train_loader)
定义好模型后使用trainer.fit进行训练,实际上后台执行的逻辑是:
1 2 3 4 5 6 7 8 9 autoencoder = LitAutoEncoder (Encoder (), Decoder ()) optimizer = autoencoder.configure_optimizers ()for batch_idx, batch in enumerate (train_loader): loss = autoencoder.training_step (batch, batch_idx) loss.backward () optimizer.step () optimizer.zero_grad ()
添加验证集及测试集 把一个数据集按预定比例拆分,分别作为训练集验证集。
1 2 3 4 5 6 7 train_set _size = int(len(train_set) * 0.8) valid_set _size = len(train_set) - train_set _size seed = torch.Generator() .manual_seed (42) train_set , valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator =seed)
在原模型中添加你需要的验证步骤,注意将验证集传入.fit使用。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 class LitAutoEncoder (L.LightningModule ): def training_step (self , batch, batch_idx ): ... def validation_step (self , batch, batch_idx ): x, _ = batch x = x.view(x.size(0 ), -1 ) z = self .encoder(x) x_hat = self .decoder(z) val_loss = F.mse_loss(x_hat, x) self .log("val_loss" , val_loss) trainer = L.Trainer () trainer.fit(model, train_loader, valid_loader)
测试集也一样,添加test_step步骤后,在.test中调用。
1 trainer.test(model, dataloaders =DataLoader(test_set))
保存/加载模型 Lightning 会自动在您当前工作目录中为您保存一个检查点,其中包含您上一个训练 epoch 的状态。这可确保在训练中断时可以恢复训练。更改路径可修改Trainer中的default_root_dir参数。通过以下方法恢复保存的参数:
1 model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt" )
早停 监控指标,在未观察到改善时停止训练。
1 2 3 4 5 6 7 8 9 10 11 from lightning.pytorch .callbacks .early_stopping import EarlyStopping class LitModel (LightningModule): def validation_step (self, batch, batch_idx): loss = ... self.log ("val_loss" , loss) model = LitModel () trainer = Trainer (callbacks=[EarlyStopping (monitor="val_loss" , mode="min" )]) trainer.fit (model)
您可以通过更改回调的参数来自定义回调行为。
1 2 early_stop_callback = EarlyStopping(monitor ="val_accuracy" , min_delta =0.00, patience =3, verbose =False , mode ="max" ) trainer = Trainer(callbacks=[early_stop_callback])
停止训练的其他参数:
stopping_threshold:一旦监测的数量达到此阈值,立即停止训练。 当我们知道超过某个最佳值不会进一步使我们受益时,它就很有用。
divergence_threshold:一旦监控的数量低于此阈值,就立即停止训练。 当达到如此糟糕的值时,我们认为模型无法再恢复,最好尽早停止并在不同的初始条件下运行。
check_finite:开启后,如果监控的指标变为 NaN 或 infinite,它会停止训练。
check_on_train_epoch_end:开启后,它会在训练 epoch 结束时检查指标。
预训练 对已经PATH训练过的模型,可以采取这种方式进行微调与新数据集的预测。
1 2 3 4 5 6 7 8 model = ImagenetTransferLearning() trainer = Trainer() trainer.fit(model) model = ImagenetTransferLearning.load_from_checkpoint(PATH ) model.freeze () x = some_images_from_cifar10() predictions = model(x)
调试模型 设置断点:在此示例中,代码将在执行y = x**2之前停止。
1 2 3 4 5 6 def function_to_debug(): x = 2 breakpoint () y = x**2
快速测试:此代码会运行5批训练、验证、测试数据,你也可以自定义数字。
1 trainer = Trainer(fast_dev_run= True)
缩短epoch:也可以用具体数字而非百分比。
1 trainer = Trainer(limit_train_batches=0 .1 , limit_val_batches=0 .01 )
发现训练瓶颈 advance暂时无法使用,simple可以打印每一步的用时,但是有点粗略可能没啥用。
可以检测下GPU/TPU/HPU容量,自行判断。
1 2 3 from lightning.pytorch .callbacks import DeviceStatsMonitor trainer = Trainer (callbacks=[DeviceStatsMonitor()] )
跟踪指标 在训练过程记录value指标:
1 2 3 def training_step (self , batch, batch_idx ): value = ... self .log("some_value" , value)
如果要记录多个指标,需要使用self.log_dict:
1 2 values = {"loss" : loss, "acc" : acc, "metric_n" : metric_n} # add more items if needed self.log_dict(values, prog_bar =True )
在命令行使用以下命令可以观察累计指标:
1 tensorboard --logdir =lightning_logs/