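# Inference script for the multimodal (text + video + audio) emotion and
# sentiment model. model_fn / predict_fn appear to follow the AWS SageMaker
# PyTorch inference-toolkit convention (an assumption based on their names);
# process_local_video() runs the same pipeline on a local file for testing.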
import torch
from models import MultimodalSentimentModel
import os
import cv2
import numpy as np
import subprocess
import torchaudio
from transformers import AutoTokenizer
import whisper
import sys
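# Index-to-label maps for the model's emotion and sentiment output heads.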
EMOTION_MAP = {0: "anger", 1: "disgust", 2: "fear",
3: "joy", 4: "neutral", 5: "sadness", 6: "surprise"}
SENTIMENT_MAP = {0: "negative", 1: "neutral", 2: "positive"}
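# Install FFmpeg at runtime: try the ffmpeg-python pip package first, then
# fall back to downloading a static amd64 build into /usr/local/bin. Returns
# True only if an "ffmpeg -version" call succeeds afterwards.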
def install_ffmpeg():
print("Starting Ffmpeg installation...")
subprocess.check_call([sys.executable, "-m", "pip",
"install", "--upgrade", "pip"])
subprocess.check_call([sys.executable, "-m", "pip",
"install", "--upgrade", "setuptools"])
try:
subprocess.check_call([sys.executable, "-m", "pip",
"install", "ffmpeg-python"])
print("Installed ffmpeg-python successfully")
    except subprocess.CalledProcessError as e:
        print(f"Failed to install ffmpeg-python via pip: {e}")
try:
subprocess.check_call([
"wget",
"https://johnvansickle.com/ffmpeg/releases/ffmpeg-release-amd64-static.tar.xz",
"-O", "/tmp/ffmpeg.tar.xz"
])
subprocess.check_call([
"tar", "-xf", "/tmp/ffmpeg.tar.xz", "-C", "/tmp/"
])
result = subprocess.run(
["find", "/tmp", "-name", "ffmpeg", "-type", "f"],
capture_output=True,
text=True
)
            # "find" can print several matches; use the first one
            ffmpeg_path = result.stdout.strip().splitlines()[0]
subprocess.check_call(["cp", ffmpeg_path, "/usr/local/bin/ffmpeg"])
subprocess.check_call(["chmod", "+x", "/usr/local/bin/ffmpeg"])
print("Installed static FFmpeg binary successfully")
except Exception as e:
print(f"Failed to install static FFmpeg: {e}")
try:
result = subprocess.run(["ffmpeg", "-version"],
capture_output=True, text=True, check=True)
print("FFmpeg version:")
print(result.stdout)
return True
except (subprocess.CalledProcessError, FileNotFoundError):
print("FFmpeg installation verification failed")
return False
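# Reads up to 30 frames from a video, resizes each to 224x224, scales pixel
# values to [0, 1], and pads with black frames when the clip is shorter.
# Output shape: [30, channels, height, width].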
class VideoProcessor:
def process_video(self, video_path):
cap = cv2.VideoCapture(video_path)
frames = []
try:
if not cap.isOpened():
raise ValueError(f"Video not found: {video_path}")
            # Try to read the first frame to validate the video
            ret, frame = cap.read()
            if not ret or frame is None:
                raise ValueError(f"Unable to read frames from video: {video_path}")
# Reset index to not skip first frame
cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
while len(frames) < 30 and cap.isOpened():
ret, frame = cap.read()
if not ret:
break
frame = cv2.resize(frame, (224, 224))
frame = frame / 255.0
frames.append(frame)
except Exception as e:
raise ValueError(f"Video error: {str(e)}")
finally:
cap.release()
        if not frames:
            raise ValueError("No frames could be extracted")
# Pad or truncate frames
if len(frames) < 30:
frames += [np.zeros_like(frames[0])] * (30 - len(frames))
else:
frames = frames[:30]
# Before permute: [frames, height, width, channels]
# After permute: [frames, channels, height, width]
return torch.FloatTensor(np.array(frames)).permute(0, 3, 1, 2)
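# Extracts a mono 16 kHz WAV track from the video with FFmpeg, converts it to
# a 64-band mel spectrogram (n_fft=1024, hop_length=512), normalizes it, and
# pads/truncates the time axis to max_length frames. The temporary WAV file
# is deleted afterwards.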
class AudioProcessor:
def extract_features(self, video_path, max_length=300):
        # splitext keeps non-.mp4 inputs from overwriting/deleting the source
        audio_path = os.path.splitext(video_path)[0] + '.wav'
try:
            subprocess.run([
                'ffmpeg',
                '-i', video_path,
                '-vn',                   # drop the video stream
                '-acodec', 'pcm_s16le',  # 16-bit PCM WAV
                '-ar', '16000',          # 16 kHz sample rate
                '-ac', '1',              # mono
                '-y',                    # overwrite any existing file
                audio_path
            ], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
waveform, sample_rate = torchaudio.load(audio_path)
if sample_rate != 16000:
resampler = torchaudio.transforms.Resample(sample_rate, 16000)
waveform = resampler(waveform)
mel_spectrogram = torchaudio.transforms.MelSpectrogram(
sample_rate=16000,
n_mels=64,
n_fft=1024,
hop_length=512
)
mel_spec = mel_spectrogram(waveform)
            # Normalize (epsilon guards against division by zero on silence)
            mel_spec = (mel_spec - mel_spec.mean()) / (mel_spec.std() + 1e-8)
            # Pad or truncate the time axis to max_length frames
            if mel_spec.size(2) < max_length:
                padding = max_length - mel_spec.size(2)
                mel_spec = torch.nn.functional.pad(mel_spec, (0, padding))
            else:
                mel_spec = mel_spec[:, :, :max_length]
return mel_spec
except subprocess.CalledProcessError as e:
raise ValueError(f"Audio extraction error: {str(e)}")
except Exception as e:
raise ValueError(f"Audio error: {str(e)}")
finally:
if os.path.exists(audio_path):
os.remove(audio_path)
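# Combines the video and audio processors and cuts individual utterance
# segments out of the source video with FFmpeg (re-encoded as H.264/AAC) so
# each Whisper segment can be processed on its own.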
class VideoUtteranceProcessor:
def __init__(self):
self.video_processor = VideoProcessor()
self.audio_processor = AudioProcessor()
def extract_segment(self, video_path, start_time, end_time, temp_dir="/tmp"):
os.makedirs(temp_dir, exist_ok=True)
segment_path = os.path.join(
temp_dir, f"segment_{start_time}_{end_time}.mp4")
subprocess.run([
"ffmpeg", "-i", video_path,
"-ss", str(start_time),
"-to", str(end_time),
"-c:v", "libx264",
"-c:a", "aac",
"-y",
segment_path
], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
if not os.path.exists(segment_path) or os.path.getsize(segment_path) == 0:
raise ValueError("Segment extraction failed: " + segment_path)
return segment_path
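# Model-loading hook (SageMaker-style, judging by the signature): installs
# FFmpeg, loads the trained weights from model_dir, and returns the model
# along with the BERT tokenizer, Whisper transcriber, and target device.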
def model_fn(model_dir):
# Load the model for inference
if not install_ffmpeg():
raise RuntimeError(
"FFmpeg installation failed - required for inference")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MultimodalSentimentModel().to(device)
model_path = os.path.join(model_dir, 'model.pth')
if not os.path.exists(model_path):
model_path = os.path.join(model_dir, "saved_models", 'checkpoint.pth')
if not os.path.exists(model_path):
raise FileNotFoundError(
"Model file not found in path " + model_path)
print("Loading model from path: " + model_path)
model.load_state_dict(torch.load(
model_path, map_location=device, weights_only=True))
model.eval()
return {
'model': model,
'tokenizer': AutoTokenizer.from_pretrained('bert-base-uncased'),
'transcriber': whisper.load_model(
"base",
device="cpu" if device.type == "cpu" else device,
),
'device': device
}
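# Prediction hook: transcribes the video with Whisper (word timestamps
# enabled), then runs the multimodal model on each transcript segment.
# Returns a dict of the form
#   {"utterances": [{"start_time", "end_time", "text",
#                    "emotions": [{"label", "confidence"}, ...],
#                    "sentiments": [{"label", "confidence"}, ...]}]}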
def predict_fn(input_data, model_dict):
model = model_dict['model']
tokenizer = model_dict['tokenizer']
device = model_dict['device']
video_path = input_data['video_path']
result = model_dict['transcriber'].transcribe(
video_path, word_timestamps=True)
utterance_processor = VideoUtteranceProcessor()
predictions = []
    for segment in result["segments"]:
        segment_path = None
        try:
segment_path = utterance_processor.extract_segment(
video_path,
segment["start"],
segment["end"]
)
video_frames = utterance_processor.video_processor.process_video(
segment_path)
audio_features = utterance_processor.audio_processor.extract_features(
segment_path)
text_inputs = tokenizer(
segment["text"],
padding="max_length",
truncation=True,
max_length=128,
return_tensors="pt"
)
# Move to device
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
video_frames = video_frames.unsqueeze(0).to(device)
audio_features = audio_features.unsqueeze(0).to(device)
# Get predictions
with torch.inference_mode():
outputs = model(text_inputs, video_frames, audio_features)
emotion_probs = torch.softmax(outputs["emotions"], dim=1)[0]
sentiment_probs = torch.softmax(
outputs["sentiments"], dim=1)[0]
emotion_values, emotion_indices = torch.topk(emotion_probs, 3)
sentiment_values, sentiment_indices = torch.topk(
sentiment_probs, 3)
predictions.append({
"start_time": segment["start"],
"end_time": segment["end"],
"text": segment["text"],
"emotions": [
{"label": EMOTION_MAP[idx.item()], "confidence": conf.item()} for idx, conf in zip(emotion_indices, emotion_values)
],
"sentiments": [
{"label": SENTIMENT_MAP[idx.item()], "confidence": conf.item()} for idx, conf in zip(sentiment_indices, sentiment_values)
]
})
except Exception as e:
print("Segment failed inference: " + str(e))
        finally:
            # Clean up the temporary segment file (may not exist if extraction failed)
            if segment_path and os.path.exists(segment_path):
                os.remove(segment_path)
return {"utterances": predictions}
def process_local_video(video_path, model_dir="."):
model_dict = model_fn(model_dir)
input_data = {'video_path': video_path}
predictions = predict_fn(input_data, model_dict)
for utterance in predictions["utterances"]:
print("\nUtterance:")
print(f"""Start: {utterance['start_time']}s, End: {
utterance['end_time']}s""")
print(f"Text: {utterance['text']}")
print("\n Top Emotions:")
for emotion in utterance['emotions']:
print(f"{emotion['label']}: {emotion['confidence']:.2f}")
print("\n Top Sentiments:")
for sentiment in utterance['sentiments']:
print(f"{sentiment['label']}: {sentiment['confidence']:.2f}")
print("-"*50)
if __name__ == "__main__":
process_local_video("./dia2_utt3.mp4")