import os

import numpy as np
import torch
import streamlit as st
from torch.optim import AdamW  # transformers' AdamW is deprecated; use the PyTorch implementation
from torch.utils.data import DataLoader, Dataset
from transformers import ViTForImageClassification, ViTImageProcessor
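# NOTE: the dataset below calls preprocess_image(), which the original snippet
# never defines. The helper here is a minimal sketch, assuming each .npy file
# holds a single 2-D MRI slice; adapt the loading and normalisation to your data.
def preprocess_image(path):
    """Load a .npy MRI slice and return a (3, 224, 224) float tensor."""
    array = np.load(path).astype(np.float32)
    # Scale intensities to [0, 1]
    array = (array - array.min()) / (array.max() - array.min() + 1e-8)
    tensor = torch.from_numpy(array)
    if tensor.ndim == 2:               # (H, W) -> (1, H, W)
        tensor = tensor.unsqueeze(0)
    if tensor.shape[0] == 1:           # grayscale -> 3 channels
        tensor = tensor.repeat(3, 1, 1)
    # Resize to the 224x224 resolution expected by vit-base-patch16-224
    tensor = torch.nn.functional.interpolate(
        tensor.unsqueeze(0), size=(224, 224), mode="bilinear", align_corners=False
    ).squeeze(0)
    # Same normalisation the ViT image processor applies (mean=0.5, std=0.5)
    return (tensor - 0.5) / 0.5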
# Custom dataset class for loading images
class MRIDataset(Dataset):
    def __init__(self, image_paths, labels):
        self.image_paths = image_paths
        self.labels = labels

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = preprocess_image(self.image_paths[idx])
        label = torch.tensor(self.labels[idx])
        return image, label
# Load the ViT backbone with a fresh 3-class classification head
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k", num_labels=3
)
# Image processor for the checkpoint (its resize/normalisation settings are
# mirrored in preprocess_image above)
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

# Move the model to the device (GPU if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Define optimizer and loss function
optimizer = AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

# Load your dataset
image_paths = ["path_to_image1.npy", "path_to_image2.npy"]  # Update with actual image paths
labels = [0, 1]  # Corresponding labels
dataset = MRIDataset(image_paths, labels)
data_loader = DataLoader(dataset, batch_size=16, shuffle=True)
# Fine-tuning loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for images, labels in data_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(pixel_values=images).logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(data_loader):.4f}')

# Save the fine-tuned model
torch.save(model.state_dict(), 'vit_finetuned.pth')
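# Inference sketch (not part of the original snippet): reload the saved weights
# and classify a single preprocessed MRI slice. Returns the predicted class index;
# map it to whatever names your three labels encode.
def predict(image_path):
    model.load_state_dict(torch.load('vit_finetuned.pth', map_location=device))
    model.eval()
    pixel_values = preprocess_image(image_path).unsqueeze(0).to(device)  # add batch dim
    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits
    return int(logits.argmax(dim=-1).item())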
def fine_tune_model():
    # Placeholder wrapper for the Streamlit button: in a real app, move the
    # fine-tuning loop above into this function so training only runs on demand.
    num_epochs = 10
    running_loss = 0.0
    for epoch in range(num_epochs):
        # Fine-tuning loop (train the model)
        # ...
        running_loss += 0.5  # Placeholder value for demo purposes
    return running_loss  # Return the final loss after training
# Streamlit UI to trigger fine-tuning and display results
st.title("MRI Image Fine-Tuning with ViT")

if st.button("Start Training"):
    # Run the fine-tuning loop when the button is clicked
    final_loss = fine_tune_model()  # Call the function containing your fine-tuning loop
    st.write(f"Training complete with final loss: {final_loss}")