Dileep7729 committed
Commit 514b8b1 · verified · 1 parent: 352c39d

Upload app.py

Files changed (1): app.py (+160, -0)
app.py ADDED
@@ -0,0 +1,160 @@
+ import os
+ import zipfile
+ import torch
+ from torch import nn, optim
+ from torch.utils.data import DataLoader, Dataset
+ from torchvision import transforms
+ from PIL import Image
+ from transformers import CLIPModel, CLIPProcessor
+ import gradio as gr
+
+ # Step 1: Unzip the dataset
+ if not os.path.exists("data"):
+     os.makedirs("data")
+
+ print("Extracting Data.zip...")
+ with zipfile.ZipFile("Data.zip", 'r') as zip_ref:
+     zip_ref.extractall("data")
+ print("Extraction complete.")
+
+ # Step 2: Dynamically find the 'safe' and 'unsafe' folders
+ def find_dataset_path(root_dir):
+     for root, dirs, files in os.walk(root_dir):
+         if 'safe' in dirs and 'unsafe' in dirs:
+             return root
+     return None
+
+ # Look for 'safe' and 'unsafe' inside 'data/Data'
+ dataset_path = find_dataset_path("data/Data")
+ if dataset_path is None:
+     print("Debugging extracted structure:")
+     for root, dirs, files in os.walk("data"):
+         print(f"Root: {root}")
+         print(f"Directories: {dirs}")
+         print(f"Files: {files}")
+     raise FileNotFoundError("Expected 'safe' and 'unsafe' folders not found inside 'data/Data'. Please check the Data.zip structure.")
+ print(f"Dataset path found: {dataset_path}")
+
+ # Step 3: Define Custom Dataset Class
+ class CustomImageDataset(Dataset):
+     def __init__(self, root_dir, transform=None):
+         self.root_dir = root_dir
+         self.transform = transform
+         self.image_paths = []
+         self.labels = []
+
+         for label, folder in enumerate(["safe", "unsafe"]):  # 0 = safe, 1 = unsafe
+             folder_path = os.path.join(root_dir, folder)
+             if not os.path.exists(folder_path):
+                 raise FileNotFoundError(f"Folder '{folder}' not found in '{root_dir}'")
+             for filename in os.listdir(folder_path):
+                 # Only load image files (case-insensitive extension check)
+                 if filename.lower().endswith((".jpg", ".jpeg", ".png")):
+                     self.image_paths.append(os.path.join(folder_path, filename))
+                     self.labels.append(label)
+
+     def __len__(self):
+         return len(self.image_paths)
+
+     def __getitem__(self, idx):
+         image_path = self.image_paths[idx]
+         image = Image.open(image_path).convert("RGB")
+         label = self.labels[idx]
+         if self.transform:
+             image = self.transform(image)
+         return image, label
+
+ # Step 4: Data Transformations
+ transform = transforms.Compose([
+     transforms.Resize((224, 224)),         # Resize to the 224x224 input CLIP expects
+     transforms.ToTensor(),                 # Convert to a [0, 1] float tensor
+     transforms.Normalize((0.5,), (0.5,)),  # Rescale values to [-1, 1]
+ ])
+
+ # Step 5: Load the Dataset
+ train_dataset = CustomImageDataset(dataset_path, transform=transform)
+ train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
+
+ # Debugging: check the dataset
+ print(f"Number of samples in the dataset: {len(train_dataset)}")
+ if len(train_dataset) == 0:
+     raise ValueError("The dataset is empty. Please check if 'Data.zip' is correctly unzipped and contains 'safe' and 'unsafe' folders.")
+
+ # Step 6: Load Pretrained CLIP Model
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ # Add a classification layer on top of the image embeddings (2 classes: safe, unsafe)
+ model.classifier = nn.Linear(model.visual_projection.out_features, 2)
+
+ # Define optimizer and loss function; only the new head is passed to the
+ # optimizer, so the pretrained CLIP weights are not updated.
+ optimizer = optim.Adam(model.classifier.parameters(), lr=1e-4)
+ criterion = nn.CrossEntropyLoss()
+
+ # Step 7: Fine-Tune the Model
+ model.train()
+ for epoch in range(3):  # Number of epochs
+     total_loss = 0
+     for images, labels in train_loader:
+         optimizer.zero_grad()
+         images = images.to(torch.float32)  # DataLoader already stacks the batch into one tensor
+         outputs = model.get_image_features(pixel_values=images)
+         logits = model.classifier(outputs)
+         loss = criterion(logits, labels)
+         loss.backward()
+         optimizer.step()
+         total_loss += loss.item()
+     print(f"Epoch {epoch + 1}, Loss: {total_loss / len(train_loader):.4f}")
+
+ # Save the fine-tuned model
+ model.save_pretrained("fine-tuned-model")
+ processor.save_pretrained("fine-tuned-model")
+ print("Model fine-tuned and saved successfully.")
+
+ # Step 8: Define Gradio Inference Function
+ def classify_image(image, class_names):
+     # Load the fine-tuned model (reloaded on every call; caching it globally would be faster)
+     model = CLIPModel.from_pretrained("fine-tuned-model")
+     processor = CLIPProcessor.from_pretrained("fine-tuned-model")
+
+     # Split class names from the comma-separated input
+     labels = [label.strip() for label in class_names.split(",") if label.strip()]
+     if not labels:
+         return {"Error": "Please enter at least one valid class name."}
+
+     # Process the image and labels, then score the image against each label
+     inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
+     outputs = model(**inputs)
+     logits_per_image = outputs.logits_per_image
+     probs = logits_per_image.softmax(dim=1)  # Convert logits to probabilities
+
+     # Map each label to its probability, highest first
+     result = {label: probs[0][i].item() for i, label in enumerate(labels)}
+     return dict(sorted(result.items(), key=lambda item: item[1], reverse=True))
+
+ # Step 9: Set Up Gradio Interface
+ iface = gr.Interface(
+     fn=classify_image,
+     inputs=[
+         gr.Image(type="pil"),
+         gr.Textbox(label="Possible class names (comma-separated)", placeholder="e.g., safe, unsafe"),
+     ],
+     outputs=gr.Label(num_top_classes=2),
+     title="Content Safety Classification",
+     description="Classify images as 'safe' or 'unsafe' using a fine-tuned CLIP model.",
+ )
+
+ # Launch Gradio Interface
+ if __name__ == "__main__":
+     iface.launch()
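A note on Step 4: Normalize((0.5,), (0.5,)) rescales inputs to [-1, 1], which is not the normalization CLIP was pretrained with, and it also differs from what CLIPProcessor applies at inference time in Step 8. A minimal sketch of keeping the two consistent by reading the statistics off the processor (assumes a recent transformers release where the processor exposes an image_processor attribute):

from torchvision import transforms
from transformers import CLIPProcessor

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Reuse CLIP's own preprocessing statistics so training and inference agree.
clip_mean = processor.image_processor.image_mean  # roughly [0.481, 0.458, 0.408]
clip_std = processor.image_processor.image_std    # roughly [0.269, 0.261, 0.276]

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=clip_mean, std=clip_std),
])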
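The Step 7 loop runs on whatever device the tensors default to, i.e. CPU on a basic Space. If a GPU is available, a minimal sketch of the usual device plumbing, with the rest of the script unchanged:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

model.train()
for epoch in range(3):
    total_loss = 0
    for images, labels in train_loader:
        # Move each batch to the same device as the model.
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        features = model.get_image_features(pixel_values=images)
        loss = criterion(model.classifier(features), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch + 1}, Loss: {total_loss / len(train_loader):.4f}")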
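One caveat about the save/load round trip: CLIPModel.from_pretrained("fine-tuned-model") in Step 8 rebuilds the model from its config, which knows nothing about the classifier layer bolted on in Step 6, so the trained head is dropped on reload (transformers warns that some checkpoint weights were not used). A minimal sketch of persisting the head separately; the classifier_head.pt filename is an arbitrary choice:

import torch
from torch import nn
from transformers import CLIPModel

# After training: store the head's weights next to the CLIP checkpoint.
torch.save(model.classifier.state_dict(), "fine-tuned-model/classifier_head.pt")

# At load time: rebuild the same Linear layer and restore its weights.
model = CLIPModel.from_pretrained("fine-tuned-model")
model.classifier = nn.Linear(model.visual_projection.out_features, 2)
model.classifier.load_state_dict(torch.load("fine-tuned-model/classifier_head.pt"))
model.eval()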
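Relatedly, classify_image in Step 8 never uses the fine-tuned head at all: it runs CLIP's zero-shot text-image matching against whatever class names the user types. To score images with the trained safe/unsafe classifier instead, something along these lines would work (assumes model and processor have been restored as in the previous sketch; predict_safety is a hypothetical helper name):

import torch

LABELS = ["safe", "unsafe"]  # order must match the training labels (0 = safe, 1 = unsafe)

def predict_safety(image):
    # Preprocess a single PIL image exactly as CLIP expects.
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        features = model.get_image_features(pixel_values=inputs["pixel_values"])
        probs = model.classifier(features).softmax(dim=1)[0]
    return {label: probs[i].item() for i, label in enumerate(LABELS)}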