import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import numpy as np
import gzip
import os
from pathlib import Path
from datetime import datetime
import urllib.request
import shutil
from tqdm import tqdm
import asyncio


def download_and_extract_mnist_data():
    """Download and extract the MNIST dataset from a reliable mirror."""
    base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"
    files = {
        "train_images": "train-images-idx3-ubyte.gz",
        "train_labels": "train-labels-idx1-ubyte.gz",
        "test_images": "t10k-images-idx3-ubyte.gz",
        "test_labels": "t10k-labels-idx1-ubyte.gz",
    }
    data_dir = Path("data/MNIST/raw")
    data_dir.mkdir(parents=True, exist_ok=True)

    for file_name in files.values():
        gz_file_path = data_dir / file_name
        extracted_file_path = data_dir / file_name.replace('.gz', '')

        # If the extracted file already exists, skip downloading.
        if extracted_file_path.exists():
            print(f"{extracted_file_path} already exists, skipping download.")
            continue

        # Download the compressed file.
        print(f"Downloading {file_name}...")
        url = base_url + file_name
        try:
            urllib.request.urlretrieve(url, gz_file_path)
            print(f"Successfully downloaded {file_name}")
        except Exception as e:
            print(f"Failed to download {file_name}: {e}")
            raise RuntimeError(f"Could not download {file_name}") from e

        # Extract the .gz archive next to it.
        try:
            print(f"Extracting {file_name}...")
            with gzip.open(gz_file_path, 'rb') as f_in:
                with open(extracted_file_path, 'wb') as f_out:
                    shutil.copyfileobj(f_in, f_out)
            print(f"Successfully extracted {file_name}")
        except Exception as e:
            print(f"Failed to extract {file_name}: {e}")
            raise RuntimeError(f"Could not extract {file_name}") from e


def load_mnist_images(filename):
    # Skip the 16-byte IDX header (magic number, count, rows, cols),
    # then scale pixel values from [0, 255] to [0, 1].
    with open(filename, 'rb') as f:
        data = np.frombuffer(f.read(), np.uint8, offset=16)
    return data.reshape(-1, 1, 28, 28).astype(np.float32) / 255.0


def load_mnist_labels(filename):
    # Skip the 8-byte IDX header (magic number, count).
    with open(filename, 'rb') as f:
        return np.frombuffer(f.read(), np.uint8, offset=8)
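

# Optional sanity-check sketch: once download_and_extract_mnist_data() has run,
# the raw arrays can be inspected directly. For the training split the expected
# shapes are (60000, 1, 28, 28) float32 images in [0, 1] and (60000,) uint8
# labels; the paths below assume the default layout used in this file.
#
#   imgs = load_mnist_images("data/MNIST/raw/train-images-idx3-ubyte")
#   lbls = load_mnist_labels("data/MNIST/raw/train-labels-idx1-ubyte")
#   assert imgs.shape[0] == lbls.shape[0] == 60000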


class CustomMNISTDataset(Dataset):
    def __init__(self, images_path, labels_path, transform=None):
        self.images = load_mnist_images(images_path)
        self.labels = load_mnist_labels(labels_path)
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = torch.FloatTensor(self.images[idx])
        label = int(self.labels[idx])
        if self.transform:
            image = self.transform(image)
        return image, label
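

# Usage sketch (hedged): the dataset plugs straight into a DataLoader, exactly
# as train() below does. Indexing it directly also works, e.g.:
#
#   ds = CustomMNISTDataset("data/MNIST/raw/t10k-images-idx3-ubyte",
#                           "data/MNIST/raw/t10k-labels-idx1-ubyte")
#   image, label = ds[0]  # image: FloatTensor of shape (1, 28, 28); label: int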


def validate(model, test_loader, criterion, device):
    """Run one pass over the test set and return (average loss, accuracy %)."""
    model.eval()
    val_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():  # no gradient computation during validation
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # criterion already averages over the batch, so accumulate the
            # per-batch mean and divide by the number of batches below.
            val_loss += criterion(output, target).item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
            num_batches += 1
    val_loss = val_loss / num_batches  # average loss across batches
    val_acc = 100. * correct / total   # accuracy over all samples
    return val_loss, val_acc
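

# Hedged usage sketch: validate() also works standalone for scoring a saved
# checkpoint. This assumes `model` is an instance of the same architecture
# that produced best_model.pth and that a test_loader like the one built in
# train() is available; neither is defined at this point in the file.
#
#   model.load_state_dict(torch.load("best_model.pth", map_location=device))
#   loss, acc = validate(model, test_loader, nn.CrossEntropyLoss(), device)
#   print(f"checkpoint: loss={loss:.4f}, acc={acc:.2f}%")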


async def train(model, config, websocket=None):
    print("\nStarting training...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model = model.to(device)

    # Ensure the data directory exists, then download/extract if needed.
    data_dir = Path("data")
    data_dir.mkdir(exist_ok=True)
    print("Preparing dataset...")
    download_and_extract_mnist_data()

    # Paths to the extracted files.
    train_images_path = "data/MNIST/raw/train-images-idx3-ubyte"
    train_labels_path = "data/MNIST/raw/train-labels-idx1-ubyte"
    test_images_path = "data/MNIST/raw/t10k-images-idx3-ubyte"
    test_labels_path = "data/MNIST/raw/t10k-labels-idx1-ubyte"

    # Data loading; images are already scaled to [0, 1] by load_mnist_images,
    # so only normalization with the standard MNIST mean/std is applied here.
    transform = transforms.Compose([
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_dataset = CustomMNISTDataset(train_images_path, train_labels_path, transform=transform)
    test_dataset = CustomMNISTDataset(test_images_path, test_labels_path, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)
    print(f"Dataset loaded. Training samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")

    # Initialize the optimizer from the config: Adam with its default
    # lr=1e-3, otherwise SGD with lr=0.01 and momentum 0.9.
    if config.optimizer.lower() == 'adam':
        optimizer = optim.Adam(model.parameters())
    else:
        optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    print("\nTraining Configuration:")
    print(f"Optimizer: {config.optimizer}")
    print(f"Batch Size: {config.batch_size}")
    print(f"Network Architecture: {config.block1}-{config.block2}-{config.block3}")
    print("\nStarting training loop...")

    best_val_acc = 0
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }
    try:
        for epoch in range(config.epochs):
            model.train()
            total_loss = 0
            correct = 0
            total = 0

            # Progress bar for this epoch.
            progress_bar = tqdm(
                train_loader,
                desc=f"Epoch {epoch+1}/{config.epochs}",
                unit='batch',
                leave=True
            )

            for batch_idx, (data, target) in enumerate(progress_bar):
                data, target = data.to(device), target.to(device)

                optimizer.zero_grad()
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

                # Accumulate running loss/accuracy for this epoch.
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
                total += target.size(0)
                total_loss += loss.item()

                current_loss = total_loss / (batch_idx + 1)
                current_acc = 100. * correct / total
                progress_bar.set_postfix({
                    'loss': f'{current_loss:.4f}',
                    'acc': f'{current_acc:.2f}%'
                })

                # Stream the training metrics over the websocket, if one is attached.
                if websocket:
                    try:
                        await websocket.send_json({
                            'type': 'training_update',
                            'data': {
                                'step': batch_idx + epoch * len(train_loader),
                                'train_loss': current_loss,
                                'train_acc': current_acc
                            }
                        })
                    except Exception as e:
                        print(f"Error sending websocket update: {e}")

            # Epoch-level training metrics.
            train_loss = total_loss / len(train_loader)
            train_acc = 100. * correct / total

            # Validation phase (reuses validate() defined above rather than
            # duplicating the same loop inline).
            print("\nRunning validation...")
            val_loss, val_acc = validate(model, test_loader, criterion, device)

            # Record the epoch metrics so the returned history is populated.
            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)

            print(f"\nEpoch {epoch+1}/{config.epochs} Results:")
            print(f"Training Loss: {train_loss:.4f} | Training Accuracy: {train_acc:.2f}%")
            print(f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_acc:.2f}%")

            # Stream the validation metrics over the websocket, if one is attached.
            if websocket:
                try:
                    await websocket.send_json({
                        'type': 'validation_update',
                        'data': {
                            'step': (epoch + 1) * len(train_loader),
                            'val_loss': val_loss,
                            'val_acc': val_acc
                        }
                    })
                except Exception as e:
                    print(f"Error sending websocket update: {e}")

            # Checkpoint the best model by validation accuracy.
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                print(f"\nNew best validation accuracy: {val_acc:.2f}%")
                print("Saving model...")
                torch.save(model.state_dict(), 'best_model.pth')
    except Exception as e:
        print(f"\nError during training: {e}")
        raise

    print("\nTraining completed!")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")
    return history
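

# Minimal usage sketch. The model below is a hypothetical stand-in: the real
# project presumably supplies its own network and a config object with these
# same fields (block1/block2/block3 are assumed to be layer widths, mirroring
# the "Network Architecture" print in train()).
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        optimizer="adam",
        batch_size=64,
        epochs=2,
        block1=128,
        block2=64,
        block3=32,
    )
    model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, config.block1), nn.ReLU(),
        nn.Linear(config.block1, config.block2), nn.ReLU(),
        nn.Linear(config.block2, config.block3), nn.ReLU(),
        nn.Linear(config.block3, 10),
    )
    # train() is a coroutine (so it can push websocket updates), so drive it
    # with asyncio when running from the command line.
    asyncio.run(train(model, config))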