import gradio as gr
import torch
import cv2
import numpy as np
import os
import json
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
import tempfile # For temporary file handling

# --- 1. Define Model Architecture (copied from small_video_classifier.py) ---
# The class definition must match the architecture used during training so the
# saved state dict can be loaded into it.
class SmallVideoClassifier(torch.nn.Module):
    def __init__(self, num_classes=2, num_frames=8):
        super(SmallVideoClassifier, self).__init__()
        from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
        try:
            weights = MobileNet_V3_Small_Weights.IMAGENET1K_V1
        except Exception:
            print("Warning: MobileNet_V3_Small_Weights.IMAGENET1K_V1 not found, initializing without pre-trained weights.")
            weights = None

        self.feature_extractor = mobilenet_v3_small(weights=weights)
        self.feature_extractor.classifier = torch.nn.Identity()
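        # MobileNetV3-Small's penultimate features are 576-dimensional per frame;
        # the pooling layer below averages them across the temporal dimension.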
        self.num_spatial_features = 576
        self.temporal_aggregator = torch.nn.AdaptiveAvgPool1d(1)
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(self.num_spatial_features, 512),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(512, num_classes)
        )

    def forward(self, pixel_values):
        batch_size, num_frames, channels, height, width = pixel_values.shape
        x = pixel_values.view(batch_size * num_frames, channels, height, width)
        spatial_features = self.feature_extractor(x)
        spatial_features = spatial_features.view(batch_size, num_frames, self.num_spatial_features)
        temporal_features = self.temporal_aggregator(spatial_features.permute(0, 2, 1)).squeeze(-1)
        logits = self.classifier(temporal_features)
        return logits
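
# Minimal shape sanity check (hedged sketch, not part of the app's control flow):
# the classifier expects a 5-D float tensor (batch, num_frames, 3, H, W) and
# returns (batch, num_classes) logits. Uncomment to verify locally:
# _demo_model = SmallVideoClassifier(num_classes=2, num_frames=8)
# _demo_input = torch.randn(1, 8, 3, 224, 224)
# print(_demo_model(_demo_input).shape)  # expected: torch.Size([1, 2])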

# --- 2. Configuration and Model Loading ---
HF_USERNAME = "owinymarvin"
NEW_MODEL_REPO_ID_SHORT = "timesformer-violence-detector"
NEW_MODEL_REPO_ID = f"{HF_USERNAME}/{NEW_MODEL_REPO_ID_SHORT}"

print(f"Downloading config.json from {NEW_MODEL_REPO_ID}...")
config_path = hf_hub_download(repo_id=NEW_MODEL_REPO_ID, filename="config.json")
with open(config_path, 'r') as f:
    model_config = json.load(f)

NUM_FRAMES = model_config.get('num_frames', 8)
IMAGE_SIZE = tuple(model_config.get('image_size', [224, 224]))
NUM_CLASSES = model_config.get('num_classes', 2)

CLASS_LABELS = ["Non-violence", "Violence"]
if NUM_CLASSES != len(CLASS_LABELS):
    print(f"Warning: NUM_CLASSES in config ({NUM_CLASSES}) does not match hardcoded CLASS_LABELS length ({len(CLASS_LABELS)}). Adjust CLASS_LABELS if needed.")

device = torch.device("cpu")
print(f"Using device: {device}")

model = SmallVideoClassifier(num_classes=NUM_CLASSES, num_frames=NUM_FRAMES)

print(f"Downloading model weights from {NEW_MODEL_REPO_ID}...")
model_weights_path = hf_hub_download(repo_id=NEW_MODEL_REPO_ID, filename="small_violence_classifier.pth")
model.load_state_dict(torch.load(model_weights_path, map_location=device))
model.to(device)
model.eval()

print(f"Model loaded successfully with {NUM_FRAMES} frames and image size {IMAGE_SIZE}.")

# --- 3. Define Preprocessing Transform ---
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
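
# For reference (hedged): each decoded frame is converted to RGB, wrapped in a PIL
# image, and mapped by `transform` to a float tensor of shape
# (3, IMAGE_SIZE[0], IMAGE_SIZE[1]) normalized with ImageNet statistics, e.g.:
# single = transform(Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8)))
# print(single.shape)  # expected: torch.Size([3, 224, 224]) with the default config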

# --- 4. Gradio Inference Function ---
def predict_video(video_path):
    if video_path is None:
        return None

    cap = cv2.VideoCapture(video_path)

    if not cap.isOpened():
        print(f"Error: Could not open video file {video_path}.")
        raise ValueError(f"Could not open video file {video_path}. Please ensure it's a valid video format.")

    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    # Guard against a reported FPS of zero (which would produce an unusable output video); default to 25 if needed.
    if fps <= 0:
        fps = 25.0
        print(f"Warning: Original video FPS was 0 or less, defaulting to {fps}.")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    temp_output_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    output_video_path = temp_output_file.name
    temp_output_file.close()

    # Codec note: 'XVID' (MPEG-4 Part 2) is widely supported by OpenCV builds, but it
    # is not natively playable in most browsers, so Gradio may still re-encode the
    # output. 'mp4v' is another common fourcc for .mp4 files written with OpenCV.
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))

    print(f"Processing video: {video_path}")
    print(f"Total frames: {total_frames}, FPS: {fps}")
    print(f"Output video will be saved to: {output_video_path}")

    frame_buffer = []
    current_prediction_label = "Processing..."

    frame_idx = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        frame_idx += 1
        
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(frame_rgb)
        
        processed_frame = transform(pil_image)
        frame_buffer.append(processed_frame)

        if len(frame_buffer) == NUM_FRAMES:
            input_tensor = torch.stack(frame_buffer, dim=0).unsqueeze(0).to(device)
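            # input_tensor shape: (1, NUM_FRAMES, 3, IMAGE_SIZE[0], IMAGE_SIZE[1]); the leading 1 is the batch dimension.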

            with torch.no_grad():
                outputs = model(input_tensor)
                probabilities = torch.softmax(outputs, dim=1)
                predicted_class_idx = torch.argmax(probabilities, dim=1).item()
                current_prediction_label = f"Prediction: {CLASS_LABELS[predicted_class_idx]} (Prob: {probabilities[0, predicted_class_idx]:.2f})"
            
            # Tumbling window: clear the buffer and wait for the next NUM_FRAMES frames.
            frame_buffer = []
            # For a sliding window instead, replace the line above with something like:
            # frame_buffer = frame_buffer[NUM_FRAMES // 2:]  # keep the newest half, slide by half the window

        # Draw prediction text on the current frame
        # Ensure text color is clearly visible (e.g., white or bright green)
        # Add a black outline for better readability
        text_color = (0, 255, 0) # Green (BGR format for OpenCV)
        text_outline_color = (0, 0, 0) # Black
        font_scale = 1.0 # Increased font size
        font_thickness = 2
        
        # Draw outline first for better readability
        cv2.putText(frame, current_prediction_label, (10, 40), # Slightly lower position
                    cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_outline_color, font_thickness + 2, cv2.LINE_AA)
        # Draw actual text
        cv2.putText(frame, current_prediction_label, (10, 40),
                    cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_color, font_thickness, cv2.LINE_AA)

        out.write(frame)

    cap.release()
    out.release()
    print(f"Video processing complete. Output saved to: {output_video_path}")
    
    return output_video_path

# --- 5. Gradio Interface Setup ---
iface = gr.Interface(
    fn=predict_video,
    inputs=gr.Video(label="Upload Video for Violence Detection (MP4 recommended)"),
    outputs=gr.Video(label="Processed Video with Predictions"),
    title="Real-time Violence Detection with SmallVideoClassifier",
    description="Upload a video, and the model will analyze it for violence, displaying the predicted class and confidence on each frame.",
    allow_flagging="never",
    examples=[
        # Add example videos here for easier testing and demonstration
        # E.g., a sample video that's publicly accessible:
        # "https://huggingface.co/datasets/gradio/test-files/resolve/main/video.mp4"
    ]
)

iface.launch()