benjaminStreltzin committed on
Commit
eb6241e
·
verified ·
1 Parent(s): ae6c854

Create Vit_Traning.py

Browse files
Files changed (1) hide show
  1. Vit_Traning.py +145 -0
Vit_Traning.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.data import Dataset, DataLoader
4
+ from torchvision import transforms
5
+ from transformers import ViTForImageClassification
6
+ from PIL import Image
7
+ import torch.optim as optim
8
+ import os
9
+ import pandas as pd
10
+ from sklearn.model_selection import train_test_split
11
+
12
def labeling(path_real, path_fake):
    """Build a labelled table of image files from two directories.

    Regular files found directly inside ``path_real`` are labelled 0 and
    regular files found directly inside ``path_fake`` are labelled 1.

    Args:
        path_real: Directory holding the real images.
        path_fake: Directory holding the fake images.

    Returns:
        A ``pandas.DataFrame`` with an ``image_path`` column (directory joined
        with the file name) and a ``label`` column. Rows for ``path_real``
        come first, and each directory's files are in sorted name order.

    Raises:
        OSError: If either directory cannot be listed (from ``os.listdir``).
    """
    image_paths = []
    labels = []

    for directory, label in ((path_real, 0), (path_fake, 1)):
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # table (and any seeded shuffle applied to it later) reproducible.
        for filename in sorted(os.listdir(directory)):
            full_path = os.path.join(directory, filename)
            # Sub-directories are not images and would fail when opened.
            if not os.path.isfile(full_path):
                continue
            image_paths.append(full_path)
            labels.append(label)

    return pd.DataFrame({'image_path': image_paths, 'label': labels})
27
+
28
class CustomDataset(Dataset):
    """Dataset yielding ``(image, label)`` pairs described by a DataFrame.

    The frame is read by position: column 0 holds the image path and
    column 1 holds the label.
    """

    def __init__(self, dataframe, transform=None):
        """Store the sample table and the optional per-image transform.

        Args:
            dataframe: Table whose first column is an image path and whose
                second column is the label.
            transform: Optional callable applied to the RGB ``PIL.Image``
                before it is returned.
        """
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Load the sample at positional index ``idx``.

        Returns:
            A ``(image, label)`` tuple: the image converted to RGB (and passed
            through ``transform`` when one was given) and the label value
            taken from the second column.
        """
        image_path = self.dataframe.iloc[idx, 0]  # Image path is in the first column

        # Open inside a context manager so the file handle is released even
        # if decoding fails; convert() returns a new image that stays usable
        # after the file is closed.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        label = self.dataframe.iloc[idx, 1]  # Label is in the second column
        return image, label
45
+
46
def shuffle_and_split_data(dataframe, test_size=0.2, random_state=59):
    """Shuffle every row of ``dataframe`` and split it into two parts.

    Args:
        dataframe: Table to split.
        test_size: Forwarded to ``train_test_split`` as ``test_size``.
        random_state: Seed used for both the row shuffle and the split.

    Returns:
        A ``(train_df, val_df)`` tuple as produced by ``train_test_split``.
    """
    # Reorder all rows, then renumber the index from zero.
    reordered = dataframe.sample(frac=1, random_state=random_state)
    reordered = reordered.reset_index(drop=True)

    splits = train_test_split(
        reordered,
        test_size=test_size,
        random_state=random_state,
    )
    training_rows, validation_rows = splits
    return training_rows, validation_rows
50
+
51
class CustomModel:
    """Frozen pre-trained ViT with a trainable two-class head fed by user data.

    Labelled samples are accumulated in a CSV file. Each time the number of
    stored samples reaches a multiple of ``retrain_interval`` the head is
    retrained on everything collected so far and the model weights are
    written to ``trained_model.pth``.
    """

    # Number of stored samples between two automatic retraining runs.
    retrain_interval = 100
    # Passes over the training split per retraining run.
    num_epochs = 10
    # Mini-batch size used for both the training and validation loaders.
    batch_size = 32

    def __init__(self):
        """Load the pre-trained model, build the optimizer and load stored samples."""
        # Check for GPU availability
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load the pre-trained ViT model and move it to the device
        self.model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(self.device)

        # Freeze pre-trained layers
        for param in self.model.parameters():
            param.requires_grad = False

        # Replace the head with a fresh two-class layer; it is created after
        # the freeze, so its parameters are the only trainable ones.
        self.model.classifier = nn.Linear(self.model.config.hidden_size, 2).to(self.device)

        # Hand the optimizer only the parameters that can actually change.
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = optim.Adam(trainable_params, lr=0.001)

        # Image preprocessing pipeline. The Normalize step matches the
        # published preprocessing of this checkpoint (mean 0.5, std 0.5 per
        # channel); without it the frozen backbone receives inputs in [0, 1]
        # rather than the [-1, 1] range it was trained on.
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        # Initialize DataFrame for user data
        self.data_file = 'user_data.csv'
        if os.path.exists(self.data_file):
            self.df = pd.read_csv(self.data_file)
        else:
            self.df = pd.DataFrame(columns=['image_path', 'label'])

    def add_data(self, image_path: str, label: int) -> None:
        """Record one labelled sample, persist the table, and retrain when due.

        Args:
            image_path: Path of the image to store.
            label: Class label for the image.
        """
        new_entry = pd.DataFrame({'image_path': [image_path], 'label': [label]})
        if self.df.empty:
            # Concatenating onto the empty object-dtype frame would turn the
            # label column into object dtype; start from the new row instead.
            self.df = new_entry
        else:
            self.df = pd.concat([self.df, new_entry], ignore_index=True)
        self.df.to_csv(self.data_file, index=False)

        # Retrain once per `retrain_interval` stored samples, not on every
        # single addition after the first threshold has been crossed.
        if len(self.df) % self.retrain_interval == 0:
            self.retrain_model()

    def retrain_model(self) -> None:
        """Retrain on all stored samples, report metrics and save the weights."""
        # Shuffle and split the data
        train_df, val_df = shuffle_and_split_data(self.df)

        # Define the datasets and dataloaders
        train_loader = DataLoader(
            CustomDataset(train_df, transform=self.preprocess),
            batch_size=self.batch_size,
            shuffle=True,
        )
        val_loader = DataLoader(
            CustomDataset(val_df, transform=self.preprocess),
            batch_size=self.batch_size,
        )

        # Define the loss function
        criterion = nn.CrossEntropyLoss().to(self.device)

        num_epochs = self.num_epochs
        for epoch in range(num_epochs):
            mean_loss = self._train_one_epoch(train_loader, criterion)
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {mean_loss}")

        accuracy = self._evaluate(val_loader)
        print(f"Validation Accuracy: {accuracy}")

        # Save the retrained model
        torch.save(self.model.state_dict(), 'trained_model.pth')
        print("Model retrained and updated!")

    def _train_one_epoch(self, loader, criterion):
        """Run one optimization pass over ``loader``; return the mean batch loss."""
        self.model.train()
        running_loss = 0.0
        for images, labels in loader:
            images, labels = images.to(self.device), labels.to(self.device)

            self.optimizer.zero_grad()
            logits = self.model(images).logits  # Extract logits from the output
            loss = criterion(logits, labels)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()
        return running_loss / len(loader)

    def _evaluate(self, loader):
        """Return the fraction of samples in ``loader`` that are classified correctly."""
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(self.device), labels.to(self.device)
                logits = self.model(images).logits
                predicted = logits.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        return correct / total
140
if __name__ == "__main__":
    # Build the model wrapper when the file is run as a script.
    custom_model = CustomModel()

    # To register a new sample, pass the image path and its label
    # (0 for real, 1 for fake), for example:
    # custom_model.add_data('path/to/image.jpg', 0)