File size: 755 Bytes
ede298f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import torch
from tqdm import tqdm
def train_epoch(model, loader, criterion, optimizer, device, scaler) -> float:
    """Train the model for one full epoch using mixed precision (AMP).

    Args:
        model: Module called as ``model(pixel_values=images)``; the returned
            object must expose a ``.logits`` attribute.
        loader: Iterable yielding ``(images, masks)`` batches. It does not
            need to define ``__len__``.
        criterion: Callable invoked as ``criterion(logits, masks)``; its
            result must support ``.item()`` and be usable by ``scaler.scale``.
        optimizer: Optimizer whose gradients are zeroed at the start of each
            batch and which is stepped through ``scaler``.
        device: Device each batch is moved to before the forward pass.
        scaler: Gradient scaler providing ``scale``, ``step`` and ``update``
            (presumably a torch AMP ``GradScaler`` -- confirm at the caller).

    Returns:
        The mean per-batch loss over the epoch, or ``0.0`` when the loader
        yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    pbar = tqdm(loader, desc="Training", leave=False)
    for images, masks in pbar:
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        # Non-deprecated spelling of torch.cuda.amp.autocast(); the device
        # type stays "cuda" so the autocast behaviour is unchanged.
        with torch.autocast(device_type="cuda"):
            logits = model(pixel_values=images).logits
            loss = criterion(logits, masks)
        # Scale the loss before backward, then let the scaler drive the
        # optimizer step and refresh its scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # .item() forces a device-to-host sync, so read it once per batch.
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        pbar.set_postfix(loss=f"{batch_loss:.4f}")
    # Average over the batches actually seen; this avoids ZeroDivisionError
    # on an empty loader and does not require len(loader).
    if num_batches == 0:
        return 0.0
    return total_loss / num_batches
|