Create example_usage.py
example_usage.py +206 -0
example_usage.py
ADDED
@@ -0,0 +1,206 @@
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
from torchvision import models, transforms
import torch.nn as nn
import os
import json
import cv2
from PIL import Image
import gradio as gr


class MultimodalRiskBehaviorModel(nn.Module):
    def __init__(self, text_model_name="bert-base-uncased", hidden_dim=512, dropout=0.3):
        super(MultimodalRiskBehaviorModel, self).__init__()

        # Text model using AutoModelForSequenceClassification
        self.text_model_name = text_model_name
        self.text_model = AutoModelForSequenceClassification.from_pretrained(text_model_name, num_labels=1)

        # Visual model (ResNet50)
        self.visual_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        visual_feature_dim = self.visual_model.fc.in_features
        self.visual_model.fc = nn.Identity()

        # Fusion and classification layer setup
        text_feature_dim = self.text_model.config.hidden_size
        self.fc1 = nn.Linear(text_feature_dim + visual_feature_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, encoding, frames):
        # Run inputs on whatever device the model's parameters live on
        device = next(self.parameters()).device
        input_ids = encoding['input_ids'].squeeze(1).to(device)
        attention_mask = encoding['attention_mask'].squeeze(1).to(device)

        # Extract text and visual features
        text_features = self.text_model(input_ids=input_ids, attention_mask=attention_mask).logits
        frames = frames.to(device)

        batch_size, num_frames, channels, height, width = frames.size()
        frames = frames.view(batch_size * num_frames, channels, height, width)
        visual_features = self.visual_model(frames)
        visual_features = visual_features.view(batch_size, num_frames, -1).mean(dim=1)

        # Combine and classify
        combined_features = torch.cat((text_features, visual_features), dim=1)
        x = self.dropout(torch.relu(self.fc1(combined_features)))
        output = torch.sigmoid(self.fc2(x))

        return output

    def save_pretrained(self, save_directory):
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))
        config = {
            "text_model_name": self.text_model_name,
            "hidden_dim": self.fc1.out_features
        }
        with open(os.path.join(save_directory, 'config.json'), 'w') as f:
            json.dump(config, f)

    @classmethod
    def from_pretrained(cls, load_directory, map_location=None):
        if os.path.exists(load_directory):
            # Load from a local directory created by save_pretrained
            config_path = os.path.join(load_directory, 'config.json')
            state_dict_path = os.path.join(load_directory, 'pytorch_model.bin')

            with open(config_path, 'r') as f:
                config_dict = json.load(f)
            model = cls(text_model_name=config_dict["text_model_name"], hidden_dim=config_dict["hidden_dim"])
            state_dict = torch.load(state_dict_path, map_location=map_location)
            model.load_state_dict(state_dict)

        else:
            # Otherwise treat load_directory as a Hugging Face Hub model id
            hf_model = AutoModelForSequenceClassification.from_pretrained(load_directory, num_labels=2)
            model = cls(text_model_name=hf_model.config.name_or_path, hidden_dim=hf_model.config.hidden_size)
            model.text_model = hf_model

        return model

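# Illustrative sketch only: save_pretrained / from_pretrained round-trip through a local
# directory; "./risk_model" is a placeholder path, not part of this commit.
#   model.save_pretrained("./risk_model")   # writes pytorch_model.bin and config.json
#   model = MultimodalRiskBehaviorModel.from_pretrained("./risk_model")
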
# Load the tokenizer and model weights from the Hugging Face Hub
tokenizer = AutoTokenizer.from_pretrained('Souha-BH/BERT_Resnet50')
model = MultimodalRiskBehaviorModel.from_pretrained('Souha-BH/BERT_Resnet50')  # on a CPU-only machine, pass map_location='cpu'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


# Function to load frames from a video
def load_frames_from_video(video_path, transform, num_frames=10):
    cap = cv2.VideoCapture(video_path)
    frames = []
    frame_count = 0
    while frame_count < num_frames:  # Limit to a number of frames for efficiency
        success, frame = cap.read()
        if not success:
            break
        # Convert frame (NumPy array) to PIL image and apply transformations
        frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        frame = transform(frame)
        frames.append(frame)
        frame_count += 1
    cap.release()

    # Stack frames and add batch dimension (1, num_frames, channels, height, width)
    frames = torch.stack(frames)
    frames = frames.unsqueeze(0)  # Add batch dimension
    return frames

def predict_video(model, video_path, text_input, tokenizer, transform):
    try:
        # Set model to evaluation mode
        model.eval()

        # Tokenize the text input
        encoding = tokenizer(
            text_input, padding='max_length', truncation=True, max_length=128, return_tensors='pt'
        )
        encoding = {key: val.to(device) for key, val in encoding.items()}

        # Load frames from the video
        frames = load_frames_from_video(video_path, transform)
        frames = frames.to(device)

        # Log input shapes and devices
        print(f"Encoding device: {next(iter(encoding.values())).device}, Frames shape: {frames.shape}")

        # Perform forward pass through the model
        with torch.no_grad():
            output = model(encoding, frames)

        # The model already outputs a sigmoid probability; threshold it to get the class
        prediction = (output.squeeze(-1) > 0.5).float()

        return prediction.item()

    except Exception as e:
        print(f"Prediction error: {e}")
        return "Error during prediction"


transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# Define your video paths and captions
video_paths = [
    'https://drive.google.com/uc?export=download&id=1iWq1q1LM-jmf4iZxOqZTw4FaIBekJowM',
    'https://drive.google.com/uc?export=download&id=1_egBaC1HD2kIZgRRKsnCtsWG94vg1c7n',
    'https://drive.google.com/uc?export=download&id=12cGxBEkfU5Q1Ezg2jRk6zGyn2hoR3JLj'
]

video_captions = [
    "Everytime i start a diet كل مرة أحاول أبدأ ريجيم 😓 #dietmemes #funnyvideos #animetiktok",
    "New sandwich from burger king 🍔👑 #mukbang #asmr #asmrmukbang #asmrsounds #eat #food #Foodie moe eats #yummy #cheese #chicken #burger #fries #burgerking @Burger King",
    "all workout guides l!nked in bi0 // honestly huge moment 😂 I’ve been so focused on growing my upper body that this feels like it finally shows! shorts from @KEEPTHATPUMP #upperbody #upperbodyworkout #glutegains #glutegrowth #gluteexercise #workout #strengthtraining #gym #trending #fyp"
]


def predict_risk(video_index):
    video_path = video_paths[video_index]
    text_input = video_captions[video_index]

    # Make prediction (0.0 / 1.0, or an error string if predict_video failed)
    prediction = predict_video(model, video_path, text_input, tokenizer, transform)
    if isinstance(prediction, str):
        return prediction

    # Return the corresponding label
    return "Risky Health Behavior" if prediction == 1 else "Not Risky Health Behavior"

# Interface setup
with gr.Blocks() as interface:
    gr.Markdown("# Risk Behavior Prediction")
    gr.Markdown("Select a video to classify its behavior as risky or not.")

    # Input option selector
    video_selector = gr.Radio(["Video 1", "Video 2", "Video 3"], label="Choose a Video")

    # Return the selected URL and caption; the URL is rendered by the Gradio `gr.Video` component
    def show_selected_video(choice):
        idx = int(choice.split()[-1]) - 1
        return video_paths[idx], f"**Caption:** {video_captions[idx]}"

    video_player = gr.Video(width=320, height=240)
    caption_box = gr.Markdown()

    video_selector.change(
        fn=show_selected_video,
        inputs=video_selector,
        outputs=[video_player, caption_box]
    )

    # Prediction button and output
    predict_button = gr.Button("Predict Risk")
    output_text = gr.Textbox(label="Prediction")

    predict_button.click(
        fn=lambda idx: predict_risk(int(idx.split()[-1]) - 1),
        inputs=video_selector,
        outputs=output_text
    )

# Launch the app
interface.launch()
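For a quick check outside the Gradio UI, the same helpers can be called directly. A minimal sketch, assuming a local clip my_clip.mp4 and a placeholder caption (neither is part of this commit):

# Placeholder inputs for illustration only
label = predict_video(model, "my_clip.mp4", "trying a 3-day water fast #fasting #fyp", tokenizer, transform)
print("Risky Health Behavior" if label == 1 else "Not Risky Health Behavior")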