必备代码
2025/9/21 · 约 375 字 · 大约 1 分钟
必备代码
训练函数模板代码
def _run_epoch(model, loader, criterion, optimizer=None, device=None):
    """Run one full pass over ``loader`` and return the mean per-batch loss.

    When ``optimizer`` is given the pass trains (train mode, gradients on,
    backward + optimizer step per batch); otherwise it evaluates (eval mode,
    gradients disabled).

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    import torch

    is_train = optimizer is not None
    # model.train(False) is equivalent to model.eval().
    model.train(is_train)
    total_loss = 0.0
    n_batches = 0
    with torch.set_grad_enabled(is_train):
        for X_batch, y_batch in loader:
            if device is not None:
                X_batch = X_batch.to(device)
                y_batch = y_batch.to(device)
            if is_train:
                optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            if is_train:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            # Count batches here instead of using len(loader) so loaders
            # without __len__ (e.g. iterable datasets) also work.
            n_batches += 1
    if n_batches == 0:
        which = "train_loader" if is_train else "valid_loader"
        raise ValueError(f"{which} yielded no batches")
    return total_loss / n_batches


def train_model(model, train_loader, valid_loader, criterion, optimizer, scheduler=None,
                num_epochs=300, patience=20, model_save_path='models/best_model.pth',
                device=None):
    """Generic PyTorch training loop with early stopping and an optional LR scheduler.

    Each epoch trains on ``train_loader``, evaluates on ``valid_loader``, steps
    the scheduler (if any), and saves ``model.state_dict()`` to
    ``model_save_path`` whenever the validation loss strictly improves.
    Training stops early after ``patience`` consecutive epochs without
    improvement.

    Note: on return, ``model`` holds the weights of the LAST completed epoch;
    the best weights are only on disk at ``model_save_path``.

    Args:
        model (nn.Module): model to train.
        train_loader (DataLoader): training data, yielding ``(X, y)`` pairs.
        valid_loader (DataLoader): validation data, yielding ``(X, y)`` pairs.
        criterion (nn.Module): loss function, e.g. ``nn.MSELoss()``.
        optimizer (torch.optim.Optimizer): optimizer, e.g. ``optim.Adam``.
        scheduler (optional): learning-rate scheduler. ``ReduceLROnPlateau``
            is stepped with the validation loss; any other scheduler is
            stepped with no arguments once per epoch. Default ``None``.
        num_epochs (int): maximum number of epochs. Default 300.
        patience (int): early-stopping patience, i.e. the number of
            consecutive epochs without validation improvement before
            stopping. Default 20.
        model_save_path (str): where the best ``state_dict`` is saved. Its
            parent directory is created if missing.
            Default ``'models/best_model.pth'``.
        device (optional): if given, each batch is moved to this device
            before the forward pass. The model itself is not moved. Default
            ``None`` (batches are used as-is).

    Returns:
        dict: ``{'train_losses': [...], 'valid_losses': [...]}`` with one
        mean per-batch loss per completed epoch.

    Raises:
        ValueError: if either loader yields no batches.
    """
    import os
    import time

    import torch

    # torch.save does not create directories; make sure the target exists.
    save_dir = os.path.dirname(model_save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    best_val_loss = float('inf')
    early_stopping_counter = 0
    train_losses, valid_losses = [], []
    start_time = time.time()

    for epoch in range(num_epochs):
        epoch_start_time = time.time()

        train_loss = _run_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(train_loss)

        valid_loss = _run_epoch(model, valid_loader, criterion, None, device)
        valid_losses.append(valid_loss)

        if scheduler is not None:
            # Only ReduceLROnPlateau takes a metric. For other schedulers the
            # positional argument is the deprecated `epoch`, so passing the
            # loss there would corrupt the learning-rate schedule.
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(valid_loss)
            else:
                scheduler.step()

        epoch_time = time.time() - epoch_start_time
        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Time: {epoch_time:.2f}s")

        # Save a checkpoint on strict improvement; otherwise count toward early stop.
        if valid_loss < best_val_loss:
            best_val_loss = valid_loss
            torch.save(model.state_dict(), model_save_path)
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= patience:
                print("Early stopping triggered")
                break

    total_time = time.time() - start_time
    print(f"Training completed. Best validation loss: {best_val_loss:.4f}")
    print(f"Total training time: {total_time:.2f}s")
    return {'train_losses': train_losses, 'valid_losses': valid_losses}