Gate Recurrent Unit(GRU)是一种门控循环神经网络,由 Cho 等人在提出 Seq2Seq 模型时引入。
设计目标是:
在缓解传统 RNN 梯度消失问题的同时,保持比 LSTM 更简单的结构与更高的训练效率。
核心思想:
通过门控机制直接在隐藏状态上进行更新,不再区分“记忆单元”和“隐状态”。
GRU 相比 RNN / LSTM 的特点
相比普通 RNN
- 引入门控结构,缓解梯度消失
- 使用加权加法路径,稳定梯度传播
2026/1/30大约 3 分钟
Gate Recurrent Unit(GRU)是一种门控循环神经网络,由 Cho 等人在提出 Seq2Seq 模型时引入。
设计目标是:
在缓解传统 RNN 梯度消失问题的同时,保持比 LSTM 更简单的结构与更高的训练效率。
核心思想:
通过门控机制直接在隐藏状态上进行更新,不再区分“记忆单元”和“隐状态”。
Long Short-Term Memory(LSTM)是一种门控循环神经网络(Gated RNN),由 Hochreiter & Schmidhuber 提出,用于解决普通 RNN 在处理长序列依赖时的梯度消失 / 梯度爆炸问题。
核心思想:
通过显式的记忆单元(cell state)和门控机制,控制信息的写入、保留与输出。
普通 RNN:
解决问题:
y=x+F(x)
def train_model(model, train_loader, valid_loader, criterion, optimizer, scheduler=None,
num_epochs=300, patience=20, model_save_path='models/best_model.pth'):
"""
通用的 PyTorch 模型训练函数,可选使用学习率调整器。
参数:
- model (nn.Module): 待训练的 PyTorch 模型。
- train_loader (DataLoader): 训练集的数据加载器。
- valid_loader (DataLoader): 验证集的数据加载器。
- criterion (nn.Module): 损失函数,例如 nn.MSELoss。
- optimizer (torch.optim.Optimizer): 优化器,例如 optim.Adam。
- scheduler (torch.optim.lr_scheduler): 可选的学习率调整器,默认为 None。
- num_epochs (int): 最大训练轮数,默认为 300。
- patience (int): 早停的耐心值,连续多少轮验证损失不下降后停止训练,默认为 20。
- model_save_path (str): 模型保存路径,默认为 'models/best_model.pth'。
返回:
- dict: 包含训练和验证损失的字典,格式为 {'train_losses': [...], 'valid_losses': [...]}。
"""
best_val_loss = float('inf')
early_stopping_counter = 0
train_losses, valid_losses = [], []
start_time = time.time()
for epoch in range(num_epochs):
epoch_start_time = time.time()
# 训练模式
model.train()
train_loss = 0.0
for X_batch, y_batch in train_loader:
optimizer.zero_grad()
outputs = model(X_batch)
loss = criterion(outputs, y_batch)
loss.backward()
optimizer.step()
train_loss += loss.item()
train_loss /= len(train_loader)
train_losses.append(train_loss)
# 验证模式
model.eval()
valid_loss = 0.0
with torch.no_grad():
for X_batch, y_batch in valid_loader:
outputs = model(X_batch)
loss = criterion(outputs, y_batch)
valid_loss += loss.item()
valid_loss /= len(valid_loader)
valid_losses.append(valid_loss)
# 调整学习率(如果有设置)
if scheduler:
scheduler.step(valid_loss)
epoch_end_time = time.time()
epoch_time = epoch_end_time - epoch_start_time
print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Time: {epoch_time:.2f}s")
# 检查验证损失是否降低
if valid_loss < best_val_loss:
best_val_loss = valid_loss
torch.save(model.state_dict(), model_save_path)
early_stopping_counter = 0
else:
early_stopping_counter += 1
if early_stopping_counter >= patience:
print("Early stopping triggered")
break
total_time = time.time() - start_time
print(f"Training completed. Best validation loss: {best_val_loss:.4f}")
print(f"Total training time: {total_time:.2f}s")
return {'train_losses': train_losses, 'valid_losses': valid_losses}