Spaces:
Sleeping
Sleeping
Update traininginVIT
Browse files- traininginVIT +23 -0
traininginVIT
CHANGED
@@ -1,6 +1,11 @@
|
|
1 |
from torch.utils.data import DataLoader, Dataset
|
2 |
import torch
|
3 |
from transformers import ViTForImageClassification, AdamW
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
# Custom dataset class for loading images
|
6 |
class MRIDataset(Dataset):
|
@@ -54,3 +59,21 @@ for epoch in range(num_epochs):
|
|
54 |
|
55 |
# Save the fine-tuned model
|
56 |
torch.save(model.state_dict(), 'vit_finetuned.pth')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from torch.utils.data import DataLoader, Dataset
|
2 |
import torch
|
3 |
from transformers import ViTForImageClassification, AdamW
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import streamlit as st
|
8 |
+
from transformers import ViTForImageClassification, ViTImageProcessor
|
9 |
|
10 |
# Custom dataset class for loading images
|
11 |
class MRIDataset(Dataset):
|
|
|
59 |
|
60 |
# Save the fine-tuned model
|
61 |
torch.save(model.state_dict(), 'vit_finetuned.pth')
|
62 |
+
|
63 |
+
def fine_tune_model():
|
64 |
+
# Your fine-tuning logic goes here (using the ViT model)
|
65 |
+
num_epochs = 10
|
66 |
+
running_loss = 0.0
|
67 |
+
for epoch in range(num_epochs):
|
68 |
+
# Fine-tuning loop (train the model)
|
69 |
+
# ...
|
70 |
+
running_loss += 0.5 # Just a placeholder for demo purposes
|
71 |
+
return running_loss # Return the final loss after training
|
72 |
+
|
73 |
+
# Streamlit UI to trigger fine-tuning and display results
|
74 |
+
st.title("MRI Image Fine-Tuning with ViT")
|
75 |
+
|
76 |
+
if st.button("Start Training"):
|
77 |
+
# Run the fine-tuning loop when the button is clicked
|
78 |
+
final_loss = fine_tune_model() # Call the function where your fine-tuning loop is
|
79 |
+
st.write(f"Training complete with final loss: {final_loss}")
|