# Framework: PyTorch
# File: BERT_Resnet50 / example_usage.py
# Author: Souha-BH
# Last commit: "Update example_usage.py" (5e252c5, verified)
# Size: 8.14 kB
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
from torchvision import models, transforms
import torch.nn as nn
import os
import json
import cv2
from PIL import Image
import gradio as gr
class MultimodalRiskBehaviorModel(nn.Module):
    """Late-fusion binary classifier over a caption and a stack of video frames.

    The caption is encoded by a Hugging Face sequence-classification model; the
    final-layer hidden state of its first token is used as the text feature
    vector. Each frame is encoded by a ResNet50 (torchvision default weights)
    whose classification layer is replaced by an identity, and the per-frame
    features are averaged over time. Text and visual vectors are concatenated
    and passed through a two-layer MLP that emits one sigmoid probability.

    Args:
        text_model_name: Hugging Face model id or local path for the text encoder.
        hidden_dim: Width of the fusion MLP's hidden layer.
        dropout: Dropout probability applied after the first fusion layer.
    """

    def __init__(self, text_model_name="bert-base-uncased", hidden_dim=512, dropout=0.3):
        super().__init__()
        # Text encoder. forward() fuses its hidden states; the 2-label
        # classification head is loaded but its logits are not used.
        self.text_model_name = text_model_name
        self.text_model = AutoModelForSequenceClassification.from_pretrained(text_model_name, num_labels=2)

        # Visual encoder: ResNet50 with the final FC replaced by an identity, so
        # it outputs the `fc.in_features`-wide pooled feature vector.
        self.visual_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        visual_feature_dim = self.visual_model.fc.in_features
        self.visual_model.fc = nn.Identity()

        # Fusion head. Its input width must equal the concatenation built in
        # forward(): hidden_size text features + visual_feature_dim.
        text_feature_dim = self.text_model.config.hidden_size
        self.fc1 = nn.Linear(text_feature_dim + visual_feature_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, encoding, frames):
        """Return a risk probability per sample.

        Args:
            encoding: Tokenizer output mapping with ``input_ids`` and
                ``attention_mask``; dimension 1 is squeezed, so a ``(B, 1, L)``
                layout becomes ``(B, L)``.
            frames: Tensor of shape ``(batch, num_frames, channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, 1)`` holding sigmoid probabilities.
        """
        # Move inputs to wherever the model's weights live, so no module-level
        # `device` variable is required.
        device = next(self.parameters()).device
        input_ids = encoding["input_ids"].squeeze(1).to(device)
        attention_mask = encoding["attention_mask"].squeeze(1).to(device)

        # Text features: final-layer hidden state of the first token ([CLS] for
        # BERT tokenizers). The classifier logits are only num_labels wide and
        # would not match fc1, which is sized for hidden_size text features.
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        text_features = text_outputs.hidden_states[-1][:, 0, :]

        # Visual features: encode every frame independently, then average over time.
        frames = frames.to(device)
        batch_size, num_frames, channels, height, width = frames.size()
        frames = frames.reshape(batch_size * num_frames, channels, height, width)
        visual_features = self.visual_model(frames)
        visual_features = visual_features.view(batch_size, num_frames, -1).mean(dim=1)

        # Fuse and classify.
        combined_features = torch.cat((text_features, visual_features), dim=1)
        x = self.dropout(torch.relu(self.fc1(combined_features)))
        return torch.sigmoid(self.fc2(x))

    def save_pretrained(self, save_directory):
        """Write ``pytorch_model.bin`` (state dict) and ``config.json`` to a directory.

        ``config.json`` holds the constructor arguments needed to rebuild the
        architecture before the state dict is loaded back.
        """
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
        config = {
            "text_model_name": self.text_model_name,
            "hidden_dim": self.fc1.out_features,
            # Persisted so a reloaded model uses the same dropout rate.
            "dropout": self.dropout.p,
        }
        with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f:
            json.dump(config, f)

    @classmethod
    def from_pretrained(cls, load_directory, map_location=None):
        """Build a model from a local save directory or a Hugging Face model id.

        If ``load_directory`` exists on disk, it must contain the files written
        by ``save_pretrained``. The architecture is rebuilt from ``config.json``
        and all weights are restored from ``pytorch_model.bin``.

        Otherwise ``load_directory`` is used as the text encoder's model id only.
        The ResNet50 keeps its torchvision default weights, and fc1/fc2 are
        freshly initialized, i.e. the fusion head is untrained.

        Args:
            load_directory: Local directory or Hugging Face model id.
            map_location: Forwarded to ``torch.load`` in the local-directory
                case (e.g. ``"cpu"``); unused for model ids.
        """
        if os.path.exists(load_directory):
            config_path = os.path.join(load_directory, "config.json")
            state_dict_path = os.path.join(load_directory, "pytorch_model.bin")
            with open(config_path, "r", encoding="utf-8") as f:
                config_dict = json.load(f)
            model = cls(
                text_model_name=config_dict["text_model_name"],
                hidden_dim=config_dict["hidden_dim"],
                # Configs written before dropout was persisted fall back to the default.
                dropout=config_dict.get("dropout", 0.3),
            )
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            state_dict = torch.load(state_dict_path, map_location=map_location)
            model.load_state_dict(state_dict)
        else:
            # The constructor already loads the text encoder from this id with
            # num_labels=2. Fetch only the config (for hidden_size) instead of
            # loading the full model a second time.
            text_config = AutoConfig.from_pretrained(load_directory)
            model = cls(text_model_name=load_directory, hidden_dim=text_config.hidden_size)
        return model
# Load the tokenizer and model once, at import time, so every request reuses them.
# The device is chosen first so it can also serve as torch.load's map_location:
# a CUDA-saved local checkpoint then loads on CPU-only machines without editing
# this call (for hub model ids, from_pretrained does not use map_location).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained('Souha-BH/BERT_Resnet50')
model = MultimodalRiskBehaviorModel.from_pretrained('Souha-BH/BERT_Resnet50', map_location=device)
model.to(device)
def load_frames_from_video(video_path, transform, num_frames=10):
    """Read up to ``num_frames`` consecutive frames from the start of a video.

    Each frame is converted from OpenCV's BGR array to an RGB PIL image and
    passed through ``transform``; the results are stacked into one tensor.

    Args:
        video_path: Path or URL accepted by ``cv2.VideoCapture``.
        transform: Callable mapping a PIL image to a tensor; all outputs must
            have the same shape so they can be stacked.
        num_frames: Maximum number of frames to read. Fewer are returned if
            the video ends first.

    Returns:
        Tensor of shape ``(1, n, *frame_shape)`` with ``1 <= n <= num_frames``.

    Raises:
        ValueError: If the video cannot be opened or yields no frames.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")
        # Only the leading frames are decoded, to bound the decoding cost.
        while len(frames) < num_frames:
            success, frame = cap.read()
            if not success:
                break
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frames.append(transform(image))
    finally:
        # Release the capture even if decoding or the transform raises.
        cap.release()

    if not frames:
        # torch.stack([]) would fail here with an unhelpful message.
        raise ValueError(f"No frames could be read from video: {video_path}")

    # Stack to (n, *frame_shape), then add the batch dimension.
    return torch.stack(frames).unsqueeze(0)
def predict_video(model, video_path, text_input, tokenizer, transform, threshold=0.5, num_frames=10):
    """Classify one (video, caption) pair and return a hard 0/1 label.

    Args:
        model: Module called as ``model(encoding, frames)`` that returns
            probabilities; ``MultimodalRiskBehaviorModel`` applies the sigmoid
            in its own forward pass.
        video_path: Path or URL of the video, read via ``load_frames_from_video``.
        text_input: Caption text to tokenize.
        tokenizer: Hugging Face tokenizer matching the model's text encoder.
        transform: Per-frame image transform.
        threshold: Probability above which the sample is labelled 1.
        num_frames: Number of leading frames to read from the video.

    Returns:
        ``1.0`` or ``0.0`` on success. On any failure the error is printed and
        the string ``"Error during prediction"`` is returned instead of raising.
    """
    import traceback

    try:
        model.eval()
        # Run on whatever device the model's weights live on.
        model_device = next(model.parameters()).device

        encoding = tokenizer(
            text_input, padding='max_length', truncation=True, max_length=128, return_tensors='pt'
        )
        encoding = {key: val.to(model_device) for key, val in encoding.items()}

        frames = load_frames_from_video(video_path, transform, num_frames=num_frames)
        frames = frames.to(model_device)

        # Log input shapes and devices.
        print(f"Encoding device: {next(iter(encoding.values())).device}, Frames shape: {frames.shape}")

        with torch.no_grad():
            output = model(encoding, frames)

        # The output is already a probability, so threshold it directly.
        prediction = (output.squeeze(-1) > threshold).float()
        return prediction.item()
    except Exception as e:
        # UI boundary: report the failure instead of crashing the app, but keep
        # the traceback so the cause can be diagnosed.
        traceback.print_exc()
        print(f"Prediction error: {e}")
        return "Error during prediction"
# Per-frame preprocessing: resize to 224x224, convert to a float tensor, and
# normalize with the ImageNet channel mean/std used by torchvision's
# pretrained ResNet weights.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Demo videos and their captions. Both lists are indexed in parallel by
# predict_risk and the UI.
video_paths = [
    'https://drive.google.com/uc?export=download&id=1iWq1q1LM-jmf4iZxOqZTw4FaIBekJowM',
    'https://drive.google.com/uc?export=download&id=1_egBaC1HD2kIZgRRKsnCtsWG94vg1c7n',
    'https://drive.google.com/uc?export=download&id=12cGxBEkfU5Q1Ezg2jRk6zGyn2hoR3JLj',
]
# The captions contain Arabic text and emoji; keep this file saved as UTF-8 so
# the tokenizer receives the real text rather than mis-decoded bytes.
video_captions = [
    "Everytime i start a diet كل مرة أحاول أبدأ ريجيم 😓 #dietmemes #funnyvideos #animetiktok",
    "New sandwich from burger king 🍔👑 #mukbang #asmr #asmrmukbang #asmrsounds #eat #food #Foodie moe eats #yummy #cheese #chicken #burger #fries #burgerking @Burger King",
    "all workout guides l!nked in bi0 // honestly huge moment 😂 I’ve been so focused on growing my upper body that this feels like it finally shows! shorts from @KEEPTHATPUMP #upperbody #upperbodyworkout #glutegains #glutegrowth #gluteexercise #workout #strengthtraining #gym #trending #fyp",
]
def predict_risk(video_index):
    """Classify one of the bundled demo videos and return a readable label.

    Args:
        video_index: Position in ``video_paths`` / ``video_captions``.

    Returns:
        ``"Risky Health Behavior"`` or ``"Not Risky Health Behavior"``. If
        prediction failed, the error message from ``predict_video`` is returned
        unchanged.
    """
    video_path = video_paths[video_index]
    text_input = video_captions[video_index]
    prediction = predict_video(model, video_path, text_input, tokenizer, transform)

    # predict_video signals failure by returning a message string. Surface it
    # instead of letting it fall through to the "not risky" label.
    if isinstance(prediction, str):
        return prediction
    return "Risky Health Behavior" if prediction == 1 else "Not Risky Health Behavior"
def _choice_to_index(choice):
    """Map a radio label like ``"Video 2"`` to a 0-based index, or None if nothing is selected."""
    if not choice:
        return None
    return int(choice.split()[-1]) - 1


# Interface setup
with gr.Blocks() as interface:
    gr.Markdown("# Risk Behavior Prediction")
    gr.Markdown("Select a video to classify its behavior as risky or not.")

    # Input option selector
    video_selector = gr.Radio(["Video 1", "Video 2", "Video 3"], label="Choose a Video")

    def show_selected_video(choice):
        """Return the selected video's URL (for gr.Video) and its caption as markdown."""
        idx = _choice_to_index(choice)
        if idx is None:
            return None, ""
        return video_paths[idx], f"**Caption:** {video_captions[idx]}"

    video_player = gr.Video(width=320, height=240)
    caption_box = gr.Markdown()
    video_selector.change(
        fn=show_selected_video,
        inputs=video_selector,
        outputs=[video_player, caption_box],
    )

    # Prediction button and output
    predict_button = gr.Button("Predict Risk")
    output_text = gr.Textbox(label="Prediction")

    def run_prediction(choice):
        """Predict for the selected video, or ask the user to pick one first."""
        idx = _choice_to_index(choice)
        if idx is None:
            return "Please select a video first."
        return predict_risk(idx)

    predict_button.click(
        fn=run_prediction,
        inputs=video_selector,
        outputs=output_text,
    )

# Launch the app
interface.launch()