PyTorch
Souha-BH committed
Commit 73ca82a · verified
1 Parent(s): 07128b2

Create example_usage.py

Files changed (1)
  1. example_usage.py +206 -0
example_usage.py ADDED
@@ -0,0 +1,206 @@
+ import torch
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
+ from torchvision import models, transforms
+ import torch.nn as nn
+ import os
+ import json
+ import cv2
+ from PIL import Image
+ import gradio as gr
+ 
+ class MultimodalRiskBehaviorModel(nn.Module):
+     def __init__(self, text_model_name="bert-base-uncased", hidden_dim=512, dropout=0.3):
+         super(MultimodalRiskBehaviorModel, self).__init__()
+ 
+         # Text model using AutoModelForSequenceClassification
+         self.text_model_name = text_model_name
+         self.text_model = AutoModelForSequenceClassification.from_pretrained(text_model_name, num_labels=1)
+ 
+         # Visual model (ResNet50)
+         self.visual_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
+         visual_feature_dim = self.visual_model.fc.in_features
+         self.visual_model.fc = nn.Identity()
+ 
+         # Fusion and classification layer setup
+         text_feature_dim = self.text_model.config.hidden_size
+         self.fc1 = nn.Linear(text_feature_dim + visual_feature_dim, hidden_dim)
+         self.dropout = nn.Dropout(dropout)
+         self.fc2 = nn.Linear(hidden_dim, 1)
+ 
+     def forward(self, encoding, frames):
+         input_ids = encoding['input_ids'].squeeze(1).to(device)
+         attention_mask = encoding['attention_mask'].squeeze(1).to(device)
+ 
+         # Extract text and visual features
+         text_features = self.text_model(input_ids=input_ids, attention_mask=attention_mask).logits
+         frames = frames.to(device)
+ 
+         batch_size, num_frames, channels, height, width = frames.size()
+         frames = frames.view(batch_size * num_frames, channels, height, width)
+         visual_features = self.visual_model(frames)
+         visual_features = visual_features.view(batch_size, num_frames, -1).mean(dim=1)
+ 
+         # Combine and classify
+         combined_features = torch.cat((text_features, visual_features), dim=1)
+         x = self.dropout(torch.relu(self.fc1(combined_features)))
+         output = torch.sigmoid(self.fc2(x))
+ 
+         return output
+ 
+     def save_pretrained(self, save_directory):
+         os.makedirs(save_directory, exist_ok=True)
+         torch.save(self.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))
+         config = {
+             "text_model_name": self.text_model_name,
+             "hidden_dim": self.fc1.out_features
+         }
+         with open(os.path.join(save_directory, 'config.json'), 'w') as f:
+             json.dump(config, f)
+ 
+     @classmethod
+     def from_pretrained(cls, load_directory, map_location=None):
+         if os.path.exists(load_directory):
+             config_path = os.path.join(load_directory, 'config.json')
+             state_dict_path = os.path.join(load_directory, 'pytorch_model.bin')
+ 
+             with open(config_path, 'r') as f:
+                 config_dict = json.load(f)
+             model = cls(text_model_name=config_dict["text_model_name"], hidden_dim=config_dict["hidden_dim"])
+             state_dict = torch.load(state_dict_path, map_location=map_location)
+             model.load_state_dict(state_dict)
+ 
+         else:
+             hf_model = AutoModelForSequenceClassification.from_pretrained(load_directory, num_labels=2)
+             model = cls(text_model_name=hf_model.config.name_or_path, hidden_dim=hf_model.config.hidden_size)
+             model.text_model = hf_model
+ 
+         return model
+ 
+ tokenizer = AutoTokenizer.from_pretrained('Souha-BH/BERT_Resnet50')
+ model = MultimodalRiskBehaviorModel.from_pretrained('Souha-BH/BERT_Resnet50')  # pass map_location='cpu' when running on CPU
+ 
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+ 
+ 
+ 
+ # Function to load frames from a video
+ def load_frames_from_video(video_path, transform, num_frames=10):
+     cap = cv2.VideoCapture(video_path)
+     frames = []
+     frame_count = 0
+     while frame_count < num_frames:  # Limit to a number of frames for efficiency
+         success, frame = cap.read()
+         if not success:
+             break
+         # Convert frame (NumPy array) to PIL image and apply transformations
+         frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+         frame = transform(frame)
+         frames.append(frame)
+         frame_count += 1
+     cap.release()
+ 
+     # Stack frames and add batch dimension (1, num_frames, channels, height, width)
+     frames = torch.stack(frames)
+     frames = frames.unsqueeze(0)  # Add batch dimension
+     return frames
+ 
+ def predict_video(model, video_path, text_input, tokenizer, transform):
+     try:
+         # Set model to evaluation mode
+         model.eval()
+ 
+         # Tokenize the text input
+         encoding = tokenizer(
+             text_input, padding='max_length', truncation=True, max_length=128, return_tensors='pt'
+         )
+         encoding = {key: val.to(device) for key, val in encoding.items()}
+ 
+         # Load frames from the video
+         frames = load_frames_from_video(video_path, transform)
+         frames = frames.to(device)
+ 
+         # Log input shapes and devices
+         print(f"Encoding device: {next(iter(encoding.values())).device}, Frames shape: {frames.shape}")
+ 
+         # Perform forward pass through the model
+         with torch.no_grad():
+             output = model(encoding, frames)
+ 
+         # The forward pass already applies sigmoid, so threshold the probability at 0.5
+         prediction = (output.squeeze(-1) > 0.5).float()
+ 
+         return prediction.item()
+ 
+     except Exception as e:
+         print(f"Prediction error: {e}")
+         return "Error during prediction"
+ 
+ 
+ 
+ 
+ transform = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ ])
+ 
+ 
+ # Define your video paths and captions
+ video_paths = [
+     'https://drive.google.com/uc?export=download&id=1iWq1q1LM-jmf4iZxOqZTw4FaIBekJowM',
+     'https://drive.google.com/uc?export=download&id=1_egBaC1HD2kIZgRRKsnCtsWG94vg1c7n',
+     'https://drive.google.com/uc?export=download&id=12cGxBEkfU5Q1Ezg2jRk6zGyn2hoR3JLj'
+ ]
+ 
+ video_captions = [
+     "Everytime i start a diet كل مرة أحاول أبدأ ريجيم 😓 #dietmemes #funnyvideos #animetiktok",
+     "New sandwich from burger king 🍔👑 #mukbang #asmr #asmrmukbang #asmrsounds #eat #food #Foodie moe eats #yummy #cheese #chicken #burger #fries #burgerking @Burger King",
+     "all workout guides l!nked in bi0 // honestly huge moment 😂 I’ve been so focused on growing my upper body that this feels like it finally shows! shorts from @KEEPTHATPUMP #upperbody #upperbodyworkout #glutegains #glutegrowth #gluteexercise #workout #strengthtraining #gym #trending #fyp"
+ ]
+ 
+ 
+ def predict_risk(video_index):
+     video_path = video_paths[video_index]
+     text_input = video_captions[video_index]
+ 
+     # Make prediction
+     prediction = predict_video(model, video_path, text_input, tokenizer, transform)
+ 
+     # Return the corresponding label
+     return "Risky Health Behavior" if prediction == 1 else "Not Risky Health Behavior"
+ 
+ # Interface setup
+ with gr.Blocks() as interface:
+     gr.Markdown("# Risk Behavior Prediction")
+     gr.Markdown("Select a video to classify its behavior as risky or not.")
+ 
+     # Input option selector
+     video_selector = gr.Radio(["Video 1", "Video 2", "Video 3"], label="Choose a Video")
+ 
+     # Return the selected video's URL (played by the gr.Video component) and its caption
+     def show_selected_video(choice):
+         idx = int(choice.split()[-1]) - 1
+         return video_paths[idx], f"**Caption:** {video_captions[idx]}"
+ 
+     video_player = gr.Video(width=320, height=240)
+     caption_box = gr.Markdown()
+ 
+     video_selector.change(
+         fn=show_selected_video,
+         inputs=video_selector,
+         outputs=[video_player, caption_box]
+     )
+ 
+     # Prediction button and output
+     predict_button = gr.Button("Predict Risk")
+     output_text = gr.Textbox(label="Prediction")
+ 
+     predict_button.click(
+         fn=lambda idx: predict_risk(int(idx.split()[-1]) - 1),
+         inputs=video_selector,
+         outputs=output_text
+     )
+ 
+ # Launch the app
+ interface.launch()
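
A single video can also be scored directly with predict_video, without launching the Gradio demo. A minimal sketch, assuming the objects defined in example_usage.py above (model, tokenizer, transform, device) are already in scope; the local path 'my_video.mp4' and the caption string are hypothetical placeholders, not files shipped with this repo:

# Minimal sketch: one prediction outside the Gradio UI.
# 'my_video.mp4' and the caption below are placeholders.
local_video = 'my_video.mp4'
caption = "example caption text for the clip"
pred = predict_video(model, local_video, caption, tokenizer, transform)
print("Risky Health Behavior" if pred == 1 else "Not Risky Health Behavior")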