Tanusree88 commited on
Commit
03aba46
·
verified ·
1 Parent(s): 9594f1f

Update traininginVIT

Browse files
Files changed (1) hide show
  1. traininginVIT +23 -0
traininginVIT CHANGED
@@ -1,6 +1,11 @@
1
  from torch.utils.data import DataLoader, Dataset
2
  import torch
3
  from transformers import ViTForImageClassification, AdamW
 
 
 
 
 
4
 
5
  # Custom dataset class for loading images
6
  class MRIDataset(Dataset):
@@ -54,3 +59,21 @@ for epoch in range(num_epochs):
54
 
55
  # Save the fine-tuned model
56
  torch.save(model.state_dict(), 'vit_finetuned.pth')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from torch.utils.data import DataLoader, Dataset
2
  import torch
3
  from transformers import ViTForImageClassification, AdamW
4
+ import os
5
+ import numpy as np
6
+ import torch
7
+ import streamlit as st
8
+ from transformers import ViTForImageClassification, ViTImageProcessor
9
 
10
  # Custom dataset class for loading images
11
  class MRIDataset(Dataset):
 
59
 
60
  # Save the fine-tuned model
61
  torch.save(model.state_dict(), 'vit_finetuned.pth')
62
+
63
+ def fine_tune_model():
64
+ # Your fine-tuning logic goes here (using the ViT model)
65
+ num_epochs = 10
66
+ running_loss = 0.0
67
+ for epoch in range(num_epochs):
68
+ # Fine-tuning loop (train the model)
69
+ # ...
70
+ running_loss += 0.5 # Just a placeholder for demo purposes
71
+ return running_loss # Return the final loss after training
72
+
73
+ # Streamlit UI to trigger fine-tuning and display results
74
+ st.title("MRI Image Fine-Tuning with ViT")
75
+
76
+ if st.button("Start Training"):
77
+ # Run the fine-tuning loop when the button is clicked
78
+ final_loss = fine_tune_model() # Call the function where your fine-tuning loop is
79
+ st.write(f"Training complete with final loss: {final_loss}")