import json
import os

import cv2
import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms
from transformers import AutoModelForSequenceClassification, AutoTokenizer


class MultimodalRiskBehaviorModel(nn.Module):
    """Fuses a BERT text encoder with a ResNet-50 frame encoder to classify a
    video (caption + sampled frames) as risky or not risky."""

    def __init__(self, text_model_name="bert-base-uncased", hidden_dim=512, dropout=0.3):
        super().__init__()

        # Text branch: a pretrained sequence-classification model; forward()
        # uses its hidden states as features.
        self.text_model_name = text_model_name
        self.text_model = AutoModelForSequenceClassification.from_pretrained(text_model_name, num_labels=2)

        # Visual branch: ResNet-50 with the classification head replaced by an
        # identity, so it emits 2048-dim frame embeddings.
        self.visual_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        visual_feature_dim = self.visual_model.fc.in_features
        self.visual_model.fc = nn.Identity()

        # Fusion head over the concatenated text and visual features.
        text_feature_dim = self.text_model.config.hidden_size
        self.fc1 = nn.Linear(text_feature_dim + visual_feature_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, encoding, frames):
        # Run on whatever device the model's parameters live on, rather than
        # relying on a module-level global.
        device = next(self.parameters()).device
        input_ids = encoding['input_ids'].squeeze(1).to(device)
        attention_mask = encoding['attention_mask'].squeeze(1).to(device)

        # Use the final-layer [CLS] hidden state as the text feature: fc1 was
        # sized for hidden_size-dimensional text features, while the
        # classification logits are only num_labels-dimensional.
        text_outputs = self.text_model(
            input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True
        )
        text_features = text_outputs.hidden_states[-1][:, 0, :]

        # Encode each frame independently, then mean-pool over the clip.
        frames = frames.to(device)
        batch_size, num_frames, channels, height, width = frames.size()
        frames = frames.view(batch_size * num_frames, channels, height, width)
        visual_features = self.visual_model(frames)
        visual_features = visual_features.view(batch_size, num_frames, -1).mean(dim=1)

        # Fuse both modalities and map to a single risk probability.
        combined_features = torch.cat((text_features, visual_features), dim=1)
        x = self.dropout(torch.relu(self.fc1(combined_features)))
        output = torch.sigmoid(self.fc2(x))

        return output

    def save_pretrained(self, save_directory):
        # Persist the weights plus the minimal config needed to rebuild the model.
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))
        config = {
            "text_model_name": self.text_model_name,
            "hidden_dim": self.fc1.out_features
        }
        with open(os.path.join(save_directory, 'config.json'), 'w') as f:
            json.dump(config, f)

    @classmethod
    def from_pretrained(cls, load_directory, map_location=None):
        if os.path.exists(load_directory):
            # Local checkpoint: rebuild the model from the saved config, then
            # restore the full multimodal state dict.
            config_path = os.path.join(load_directory, 'config.json')
            state_dict_path = os.path.join(load_directory, 'pytorch_model.bin')

            with open(config_path, 'r') as f:
                config_dict = json.load(f)
            model = cls(text_model_name=config_dict["text_model_name"], hidden_dim=config_dict["hidden_dim"])
            state_dict = torch.load(state_dict_path, map_location=map_location)
            model.load_state_dict(state_dict)
        else:
            # Hub identifier: only the text encoder is fetched from the
            # Hugging Face Hub; the visual branch keeps its ImageNet weights
            # and the fusion head stays randomly initialized.
            hf_model = AutoModelForSequenceClassification.from_pretrained(load_directory, num_labels=2)
            model = cls(text_model_name=load_directory, hidden_dim=hf_model.config.hidden_size)
            model.text_model = hf_model

        return model


tokenizer = AutoTokenizer.from_pretrained('Souha-BH/BERT_Resnet50')
model = MultimodalRiskBehaviorModel.from_pretrained('Souha-BH/BERT_Resnet50')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
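
# A minimal sketch (hypothetical local path) of the save/load round trip that
# save_pretrained/from_pretrained implement; uncomment to persist and restore
# the full multimodal checkpoint rather than just the text encoder:
# model.save_pretrained('checkpoints/risk_model')
# model = MultimodalRiskBehaviorModel.from_pretrained('checkpoints/risk_model', map_location=device)
# model.to(device)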

def load_frames_from_video(video_path, transform, num_frames=10):
    # Read up to num_frames consecutive frames from the start of the video.
    cap = cv2.VideoCapture(video_path)
    frames = []
    frame_count = 0
    while frame_count < num_frames:
        success, frame = cap.read()
        if not success:
            break

        # OpenCV decodes to BGR; convert to RGB before the PIL-based transform.
        frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        frame = transform(frame)
        frames.append(frame)
        frame_count += 1
    cap.release()

    if not frames:
        raise ValueError(f"Could not read any frames from {video_path}")

    # Stack into (num_frames, C, H, W) and add a batch dimension.
    frames = torch.stack(frames)
    frames = frames.unsqueeze(0)
    return frames
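
# Expected output shape, assuming a readable local file 'sample.mp4' (a
# hypothetical path) and the `transform` defined further below:
# clip = load_frames_from_video('sample.mp4', transform)
# clip.shape  # -> torch.Size([1, 10, 3, 224, 224])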

def predict_video(model, video_path, text_input, tokenizer, transform):
    try:
        model.eval()

        # Tokenize the caption to a fixed length and move it to the device.
        encoding = tokenizer(
            text_input, padding='max_length', truncation=True, max_length=128, return_tensors='pt'
        )
        encoding = {key: val.to(device) for key, val in encoding.items()}

        # Sample and preprocess frames from the video.
        frames = load_frames_from_video(video_path, transform)
        frames = frames.to(device)

        print(f"Encoding device: {next(iter(encoding.values())).device}, Frames shape: {frames.shape}")

        # Forward pass without gradient tracking, then threshold the sigmoid
        # output at 0.5 for a binary label.
        with torch.no_grad():
            output = model(encoding, frames)

        prediction = (output.squeeze(-1) > 0.5).float()
        return prediction.item()

    except Exception as e:
        # Return None so the caller can report the failure explicitly.
        print(f"Prediction error: {e}")
        return None
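
# Example invocation (with a hypothetical local file; the demo below feeds in
# Google Drive URLs instead, relying on OpenCV's URL support):
# result = predict_video(model, 'sample.mp4', 'caption text', tokenizer, transform)
# result is 1.0 (risky), 0.0 (not risky), or None on failure.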

# Standard ImageNet preprocessing, matching the ResNet-50 pretraining.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
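
# Quick sanity check of the preprocessing on a blank PIL image:
# x = transform(Image.new('RGB', (640, 360)))
# x.shape  # -> torch.Size([3, 224, 224])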

video_paths = [
    'https://drive.google.com/uc?export=download&id=1iWq1q1LM-jmf4iZxOqZTw4FaIBekJowM',
    'https://drive.google.com/uc?export=download&id=1_egBaC1HD2kIZgRRKsnCtsWG94vg1c7n',
    'https://drive.google.com/uc?export=download&id=12cGxBEkfU5Q1Ezg2jRk6zGyn2hoR3JLj'
]

video_captions = [
    # The Arabic reads roughly: "every time I try to start a diet".
    "Everytime i start a diet كل مرة أحاول أبدأ ريجيم #dietmemes #funnyvideos #animetiktok",
    "New sandwich from burger king #mukbang #asmr #asmrmukbang #asmrsounds #eat #food #Foodie moe eats #yummy #cheese #chicken #burger #fries #burgerking @Burger King",
    "all workout guides l!nked in bi0 // honestly huge moment I've been so focused on growing my upper body that this feels like it finally shows! shorts from @KEEPTHATPUMP #upperbody #upperbodyworkout #glutegains #glutegrowth #gluteexercise #workout #strengthtraining #gym #trending #fyp"
]

def predict_risk(video_index):
    video_path = video_paths[video_index]
    text_input = video_captions[video_index]

    prediction = predict_video(model, video_path, text_input, tokenizer, transform)

    if prediction is None:
        return "Error during prediction"
    return "Risky Health Behavior" if prediction == 1 else "Not Risky Health Behavior"
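
# The Gradio UI below maps a radio choice ("Video 1" .. "Video 3") to a
# 0-based index into the lists above, e.g.:
# predict_risk(0)  # -> "Risky Health Behavior", "Not Risky Health Behavior",
#                  #    or "Error during prediction"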

with gr.Blocks() as interface:
    gr.Markdown("# Risk Behavior Prediction")
    gr.Markdown("Select a video to classify its behavior as risky or not.")

    video_selector = gr.Radio(["Video 1", "Video 2", "Video 3"], label="Choose a Video")

    def show_selected_video(choice):
        # Map the radio label "Video N" to a 0-based index.
        idx = int(choice.split()[-1]) - 1
        return video_paths[idx], f"**Caption:** {video_captions[idx]}"

    video_player = gr.Video(width=320, height=240)
    caption_box = gr.Markdown()

    video_selector.change(
        fn=show_selected_video,
        inputs=video_selector,
        outputs=[video_player, caption_box]
    )

    predict_button = gr.Button("Predict Risk")
    output_text = gr.Textbox(label="Prediction")

    predict_button.click(
        fn=lambda choice: predict_risk(int(choice.split()[-1]) - 1) if choice else "Please select a video first.",
        inputs=video_selector,
        outputs=output_text
    )

interface.launch()