TatsukichiHayama / Train.py
Yuzmi's picture
Create Train.py
27010fd
raw
history blame
1.12 kB
import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.optim as optim
# Define your model class
class TatsukichiHayamaClassifier(nn.Module):
    """Minimal convolutional baseline for image classification.

    The class body was previously an empty placeholder, which is not valid
    Python. This is a deliberately small stand-in so the training script can
    run end to end; replace the layers with a real architecture as needed.

    Args:
        num_classes: Number of output logits, one per target class.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """

    def __init__(self, num_classes: int) -> None:
        super().__init__()
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")

        # Two conv stages followed by global average pooling. The adaptive
        # pool collapses the spatial dimensions to 1x1, so the linear head
        # does not depend on the input resolution.
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(32, num_classes)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Return raw class logits of shape ``(batch, num_classes)``.

        Args:
            images: Float tensor of shape ``(batch, 3, height, width)``.
        """
        pooled = self.features(images)
        # Flatten (batch, 32, 1, 1) -> (batch, 32) for the linear head.
        return self.classifier(torch.flatten(pooled, 1))
# Load the training images. ImageFolder expects one sub-directory per class
# under the root and assigns integer labels 0..len(classes)-1.
# NOTE: only ToTensor() is applied, so the default batch collation requires
# every image in the dataset to have the same dimensions.
train_dataset = datasets.ImageFolder(
    root="TatsukichiHayamaDataset",
    transform=transforms.ToTensor(),
)

# Shuffled mini-batches for training.
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

# Derive the class count from the dataset instead of hard-coding it, so the
# model's output size always matches the label range that CrossEntropyLoss
# will see.
your_num_classes = len(train_dataset.classes)
model = TatsukichiHayamaClassifier(num_classes=your_num_classes)

# Loss function and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop.
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    for images, labels in dataloader:
        # Clear gradients left over from the previous step.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    # Reports the loss of the last batch processed in this epoch.
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')