"""
# ==============================================================================
# Vision-Language and Face Recognition Utilities
# ==============================================================================
This module provides helper functions, lazy-loading mechanisms, and
API endpoint wrappers for multimodal inference, face recognition, and
video scene extraction.
It includes functionality for:
- Lazy initialization of heavyweight models (vision-language and face models)
- Image and video preprocessing
- Multimodal inference with configurable parameters (token limits, temperature)
- Facial embedding generation
- Scene extraction from video files
- Gradio UI components and endpoint definitions for user interaction
All functions and utilities are designed to be:
- Reusable, caching heavy models to avoid repeated loading
- Compatible with GPU/CPU execution
- Stateless and safe to call concurrently from multiple requests
- Modular, separating model logic from endpoint and UI handling
This module serves as the core interface layer between client-facing
APIs/UI and the underlying machine learning models.
# ==============================================================================
"""
# Standard library
import json
import os
from typing import Any, Dict, List, Optional, Tuple, Union
# Third-party libraries
import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from facenet_pytorch import InceptionResnetV1, MTCNN
from PIL import Image
from scenedetect import SceneManager, VideoManager
from scenedetect.detectors import ContentDetector
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
'''
# ==============================================================================
# Lazy-loading utilities for vision-language and face recognition models
# ==============================================================================
This section provides on-demand initialization of heavyweight components, including:
- MTCNN: Face detector used to locate and align faces.
- FaceNet (InceptionResnetV1): Generates 512-dimensional facial embeddings.
- LLaVA OneVision: Vision-language model for multimodal inference.
By loading models lazily and caching them in global variables, the system avoids
unnecessary reinitialization and reduces startup time, improving performance in
production environments such as FastAPI services, Docker deployments, and
Hugging Face Spaces.
# ==============================================================================
'''
MODEL_ID = os.environ.get("MODEL_ID", "BSC-LT/salamandra-7b-vision")
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
_model = None
_processor = None
_mtcnn = None
_facenet = None
def _load_face_models() -> Tuple[MTCNN, InceptionResnetV1]:
"""
Lazily loads and initializes the facial detection and facial embedding models.
This function loads:
- **MTCNN**: Used for face detection and cropping.
- **InceptionResnetV1 (FaceNet)**: Used to generate 512-dimensional face embeddings.
Both models are loaded only once and stored in global variables to avoid
unnecessary re-initialization. They are automatically placed on GPU if available,
otherwise CPU is used.
Returns:
Tuple[MTCNN, InceptionResnetV1]: A tuple containing the initialized
face detection model and the face embedding model.
"""
global _mtcnn, _facenet
if _mtcnn is None or _facenet is None:
device = DEVICE if DEVICE == "cuda" and torch.cuda.is_available() else "cpu"
_mtcnn = MTCNN(image_size=160, margin=0, post_process=True, device=device)
_facenet = InceptionResnetV1(pretrained="vggface2").eval().to(device)
return _mtcnn, _facenet
def _lazy_load() -> Tuple[LlavaOnevisionForConditionalGeneration, AutoProcessor]:
"""
Lazily loads the vision-language model and its processor.
This function performs a first-time load of:
- **AutoProcessor**: Handles preprocessing of text and images for the model.
- **LlavaOnevisionForConditionalGeneration**: The main multimodal model used
for inference and text generation.
The model is moved to GPU if available and configured with:
- The appropriate floating-point precision (`float16` or `float32`)
- Low memory usage mode
- SafeTensors loading enabled
Both components are cached in global variables to ensure subsequent calls
reuse the loaded instances without reinitialization.
Returns:
Tuple[LlavaOnevisionForConditionalGeneration, AutoProcessor]:
The loaded model and processor ready for inference.
"""
global _model, _processor
if _model is None or _processor is None:
_processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
_model = LlavaOnevisionForConditionalGeneration.from_pretrained(
MODEL_ID,
dtype=DTYPE,
low_cpu_mem_usage=True,
trust_remote_code=True,
use_safetensors=True,
device_map=None,
)
_model.to(DEVICE)
return _model, _processor
'''
# ==============================================================================
# Auxiliary Model Loading Utilities for API Endpoints
# ==============================================================================
This section contains helper functions used internally by the API endpoints to
efficiently load and manage heavy machine learning components. These utilities
handle on-demand initialization ("lazy loading") of both the vision-language
model (LLaVA OneVision) and the facial detection/embedding models (MTCNN and
FaceNet).
The goal of this helper block is to:
- Avoid repeated loading of large models across requests.
- Reduce GPU/CPU memory pressure by reusing cached instances.
- Provide clean separation between endpoint logic and model-handling logic.
- Improve performance and stability in production environments
(FastAPI, Docker, Hugging Face Spaces).
All functions here are intended for internal use and should be called by
endpoint handlers when a model is required for a given request.
# ==============================================================================
'''
@spaces.GPU
def _infer_one(
image: Image.Image,
text: str,
max_new_tokens: int = 256,
temperature: float = 0.7,
context: Optional[Dict] = None,
) -> str:
"""
Run a single multimodal inference step using the LLaVA OneVision model.
This function:
- Optionally downsizes the input image to reduce GPU memory consumption.
- Loads the model and processor through lazy initialization.
- Builds the final prompt by applying the chat template and injecting optional context.
- Performs autoregressive generation with configurable token and temperature settings.
- Returns the decoded textual output.
Args:
image (Image.Image): Input PIL image used for multimodal conditioning.
text (str): User-provided instruction or query.
max_new_tokens (int): Maximum number of tokens to generate.
temperature (float): Sampling temperature controlling output randomness.
context (Optional[Dict]): Additional context injected into the prompt.
Returns:
str: The generated textual response.
"""
image.thumbnail((1024, 1024))
model, processor = _lazy_load()
prompt = processor.apply_chat_template(_compose_prompt(text, context), add_generation_prompt=True)
inputs = processor(images=image, text=prompt, return_tensors="pt").to(DEVICE, dtype=DTYPE)
    do_sample = float(temperature) > 0
    gen_kwargs: Dict[str, Any] = {"max_new_tokens": int(max_new_tokens), "do_sample": do_sample}
    if do_sample:
        # temperature only takes effect when sampling is enabled
        gen_kwargs["temperature"] = float(temperature)
    with torch.inference_mode():
        out = model.generate(**inputs, **gen_kwargs)
    # Decode only the newly generated tokens, skipping the echoed prompt
    new_tokens = out[0][inputs["input_ids"].shape[-1]:]
    return processor.decode(new_tokens, skip_special_tokens=True).strip()
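# Illustrative sketch (kept as a comment so nothing runs at import time): how `_infer_one`
# can be called directly with an optional context dict. The file name and context keys
# below are placeholders, not part of the application.
#
#   img = Image.open("example.jpg")
#   answer = _infer_one(img, "What is happening in this scene?",
#                       max_new_tokens=128, temperature=0.2,
#                       context={"source": "episode 3", "scene": 12})
#   print(answer)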
@spaces.GPU
def _get_face_embedding(
image: Image.Image
) -> list[float] | None:
"""
Generate a FaceNet embedding for a single face in an image.
Args:
image (Image.Image): A PIL Image containing a face.
Returns:
list[float] | None: Normalized embedding vector for the detected face,
or None if no face is detected or an error occurs.
"""
try:
mtcnn, facenet = _load_face_models()
# Detect and extract face
face = mtcnn(image)
if face is None:
return None
# FaceNet expects tensor of shape (1,3,160,160)
device = DEVICE if DEVICE == "cuda" and torch.cuda.is_available() else "cpu"
face = face.unsqueeze(0).to(device)
# Get embedding
with torch.no_grad():
emb = facenet(face).cpu().numpy()[0]
# Normalize embedding
emb = emb / np.linalg.norm(emb)
return emb.astype(float).tolist()
except Exception as e:
print(f"Face embedding failed: {e}")
return None
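# Illustrative sketch: because `_get_face_embedding` returns L2-normalized vectors, two
# embeddings can be compared with a plain dot product (cosine similarity). The helper
# name and the 0.7 threshold below are examples only, not part of this application.
#
#   def _same_person(emb_a: List[float], emb_b: List[float], threshold: float = 0.7) -> bool:
#       similarity = float(np.dot(np.array(emb_a), np.array(emb_b)))
#       return similarity >= threshold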
@spaces.GPU
def _get_scenes_extraction(
video_file: str,
threshold: float,
offset_frames: int,
crop_ratio: float
) -> Tuple[List[Image.Image], List[Dict]] | None:
"""
Extracts scenes from a video and returns cropped images along with information about each scene.
Args:
video_file (str): Path to the video file.
threshold (float): Threshold for scene detection.
offset_frames (int): Frame offset from the start of each scene.
crop_ratio (float): Central crop ratio for each frame.
Returns:
Tuple[List[Image.Image], List[Dict]] | None: List of scene images and list of scene information,
or (None, None) if an error occurs.
"""
try:
# Initialize video and scene managers
video_manager = VideoManager([video_file])
scene_manager = SceneManager()
scene_manager.add_detector(ContentDetector(threshold=threshold))
video_manager.start()
scene_manager.detect_scenes(video_manager)
        scene_list = scene_manager.get_scene_list()
        video_manager.release()
        cap = cv2.VideoCapture(video_file)
images: List[Image.Image] = []
scene_info: List[Dict] = []
for i, (start_time, end_time) in enumerate(scene_list):
frame_number = int(start_time.get_frames()) + offset_frames
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
ret, frame = cap.read()
if not ret:
continue
            h, w = frame.shape[:2]
            # Central crop: keep the central `crop_ratio` fraction of the frame
            # (crop_ratio = 1.0 keeps the full frame)
            margin_h = int(h * (1.0 - crop_ratio) / 2)
            margin_w = int(w * (1.0 - crop_ratio) / 2)
            cropped_frame = frame[margin_h:h - margin_h, margin_w:w - margin_w]
# Convert to RGB and save as an image
img_rgb = cv2.cvtColor(cropped_frame, cv2.COLOR_BGR2RGB)
images.append(Image.fromarray(img_rgb))
# Store scene information
scene_info.append({
"index": i + 1,
"start": start_time.get_seconds(),
"end": end_time.get_seconds()
})
cap.release()
return images, scene_info
except Exception as e:
print("Error in scenes_extraction:", e)
return None, None
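# Illustrative sketch: pairing extracted keyframes with captions by chaining the helpers
# in this module. The video path and parameter values below are placeholders only.
#
#   frames, info = _get_scenes_extraction("sample.mp4", threshold=30.0,
#                                         offset_frames=5, crop_ratio=0.8)
#   if frames:
#       captions = describe_list_images(frames)
#       for meta, caption in zip(info, captions):
#           print(f"Scene {meta['index']} ({meta['start']:.1f}s-{meta['end']:.1f}s): {caption}")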
"""
# ==============================================================================
# API Helpers
# ==============================================================================
Collection of public-facing API endpoints used by the application.
This section exposes functions that process incoming requests,
perform validation, interact with the model inference helpers,
and return structured responses. Each endpoint is designed to be
stateless, deterministic, and safe to call from external clients.
Endpoints in this module typically:
- Receive raw data (images, text, base64-encoded content, etc.)
- Preprocess inputs before forwarding them to internal inference utilities
- Handle optional parameters such as temperature or token limits
- Return JSON-serializable dictionaries as responses
The functions below constitute the interface layer between users
and the underlying model logic implemented in the helper utilities.
# ==============================================================================
"""
def describe_raw(image: Image.Image, text: str = "Describe la imagen con detalle.",
max_new_tokens: int = 256, temperature: float = 0.7) -> Dict[str, str]:
"""
Endpoint to generate a detailed description of an input image.
This function receives an image and an optional text prompt, then forwards
the request to the internal inference helper `_infer_one`. It returns a JSON-
serializable dictionary containing the generated text description.
Parameters
----------
image : PIL.Image.Image
The input image to be analyzed and described.
text : str, optional
Instruction or prompt for the model guiding how the image should be described.
Defaults to a general "describe in detail" prompt (in Spanish).
max_new_tokens : int, optional
Maximum number of tokens the model is allowed to generate. Default is 256.
temperature : float, optional
Sampling temperature controlling randomness of the output. Default is 0.7.
Returns
-------
Dict[str, str]
A dictionary with a single key `"text"` containing the generated description.
"""
result = _infer_one(image, text, max_new_tokens, temperature, context=None)
return {"text": result}
def describe_batch(
images: List[Image.Image],
context_json: str,
max_new_tokens: int = 256,
temperature: float = 0.7
) -> List[str]:
"""
Batch endpoint for the image description engine.
This endpoint receives a list of images along with an optional JSON-formatted
context, and returns a list of textual descriptions generated by the model.
Each image is processed individually using the internal `_infer_one` function,
optionally incorporating the context into the prompt.
Args:
images (List[Image.Image]):
A list of PIL Image objects to describe.
context_json (str):
A JSON-formatted string providing additional context for the prompt.
If empty or invalid, no context will be used.
max_new_tokens (int, optional):
Maximum number of tokens to generate per image. Defaults to 256.
temperature (float, optional):
Sampling temperature controlling text randomness. Defaults to 0.7.
Returns:
List[str]: A list of text descriptions, one for each input image, in order.
"""
try:
context = json.loads(context_json) if context_json else None
except Exception:
context = None
outputs: List[str] = []
for img in images:
outputs.append(_infer_one(img, text="Describe la imagen con detalle.", max_new_tokens=max_new_tokens,
temperature=temperature, context=context))
return outputs
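# Illustrative example of a `context_json` payload. The content is free-form: it is parsed
# with json.loads and injected verbatim into the prompt (truncated to 2000 characters in
# `_compose_prompt`); the keys below are only an example, not a required schema.
#
#   context_json = '{"title": "Episodi 3", "characters": ["Anna", "Marc"], "setting": "a beach at sunset"}'
#   descriptions = describe_batch([Image.open("frame_001.png")], context_json, max_new_tokens=128)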
def face_image_embedding(image: Image.Image) -> List[float] | None:
"""
Endpoint to generate a face embedding for a given image.
This function wraps the core `_get_face_embedding` logic for use in endpoints.
    The MTCNN and FaceNet models are loaded lazily on the first call.
    Args:
        image (Image.Image): Input image containing a face.
Returns:
list[float] | None: Normalized embedding vector or None if no face detected.
"""
return _get_face_embedding(image)
def scenes_extraction(
video_file: str,
threshold: float,
offset_frames: int,
crop_ratio: float
) -> Tuple[List[Image.Image], List[Dict]] | None:
"""
Endpoint wrapper for extracting scenes from a video.
This function acts as a wrapper around the internal `_get_scenes_extraction` function.
It handles a video file provided as a string path (as Gradio temporarily saves uploaded files)
and returns the extracted scene images along with scene metadata.
Args:
video_file (str): Path to the uploaded video file.
threshold (float): Threshold for scene detection.
offset_frames (int): Frame offset from the start of each detected scene.
crop_ratio (float): Central crop ratio to apply to each extracted frame.
Returns:
Tuple[List[Image.Image], List[Dict]] | None: A tuple containing:
- A list of PIL Images representing each extracted scene.
- A list of dictionaries with scene information (index, start time, end time).
Returns (None, None) if an error occurs during extraction.
"""
return _get_scenes_extraction(video_file, threshold, offset_frames, crop_ratio)
@spaces.GPU
def describe_list_images(
images: List[Image.Image]
) -> List[str]:
"""
Generate brief visual descriptions for a list of PIL Images using Salamandra Vision.
Args:
images (List[Image.Image]): List of PIL Image objects to describe.
Returns:
List[str]: List of descriptions, one per image.
"""
    # Reuse the cached Salamandra Vision model/processor instead of reloading them on every call
    model, processor = _lazy_load()
# System prompt for image description
sys_prompt = (
"You are an expert in visual storytelling. "
"Describe the image very briefly and simply in Catalan, "
"explaining only the main action seen. "
"Respond with a single short sentence (maximum 10–20 words), "
"without adding unnecessary details or describing the background."
)
all_results = []
for img in images:
batch = [img]
# Create the conversation template
conversation = [
{"role": "system", "content": sys_prompt},
{"role": "user", "content": [
{"type": "image", "image": batch[0]},
{"type": "text", "text": (
"Describe the image very briefly and simply in Catalan."
)}
]}
]
prompt_batch = processor.apply_chat_template(conversation, add_generation_prompt=True)
# Prepare inputs for the model
inputs = processor(images=batch, text=prompt_batch, return_tensors="pt")
        # Move tensors to the model's device, casting only floating-point tensors to the model dtype
        for k, v in inputs.items():
            if v.dtype.is_floating_point:
                inputs[k] = v.to(DEVICE, DTYPE)
            else:
                inputs[k] = v.to(DEVICE)
# Generate the description
output = model.generate(**inputs, max_new_tokens=1024)
text = processor.decode(output[0], skip_special_tokens=True)
lines = text.split("\n")
# Extract the assistant's answer
desc = ""
for i, line in enumerate(lines):
if line.lower().startswith(" assistant"):
desc = "\n".join(lines[i+1:]).strip()
break
print("====================")
print(desc)
all_results.append(desc)
return all_results
"""
# ==============================================================================
# UI & Endpoints
# ==============================================================================
Collection of Gradio interface elements and API endpoints used by the application.
This section defines the user-facing interface for Salamandra Vision 7B,
allowing users to interact with the model through images, text prompts,
video uploads, and batch operations.
The components and endpoints in this module typically:
- Accept images, text, or video files from the user
- Apply optional parameters such as temperature, token limits, or crop ratios
- Preprocess inputs and invoke internal inference or utility functions
- Return structured outputs, including text descriptions, JSON metadata,
or image galleries
All endpoints are designed to be stateless, safe for concurrent calls,
and compatible with both interactive UI usage and programmatic API access.
# ==============================================================================
"""
def _compose_prompt(user_text: str, context: Optional[Dict] = None) -> List[Dict]:
"""
Build the chat template with an image, text, and optional context.
Args:
user_text (str): Text provided by the user.
context (Optional[Dict]): Optional additional context.
Returns:
List[Dict]: A conversation template for the model, including the image and text.
"""
ctx_txt = ""
if context:
try:
# Keep context brief and clean
ctx_txt = "\n\nAdditional context:\n" + json.dumps(context, ensure_ascii=False)[:2000]
except Exception:
pass
user_txt = (user_text or "Describe the image in detail.") + ctx_txt
convo = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": user_txt},
],
}
]
return convo
with gr.Blocks(title="Salamandra Vision 7B · ZeroGPU") as demo:
gr.Markdown("## Salamandra-Vision 7B · ZeroGPU\nImage + text → description.")
with gr.Row():
with gr.Column():
in_img = gr.Image(label="Image", type="pil")
in_txt = gr.Textbox(label="Text/prompt", value="Describe the image in detail (ES/CA).")
max_new = gr.Slider(16, 1024, value=256, step=16, label="max_new_tokens")
temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature")
btn = gr.Button("Generate", variant="primary")
with gr.Column():
out = gr.Textbox(label="Description", lines=18)
# Single image inference
btn.click(_infer_one, [in_img, in_txt, max_new, temp], out, api_name="describe", concurrency_limit=1)
# Batch API for engine (Gradio Client): images + context_json → list[str]
batch_in_images = gr.Gallery(label="Batch images", show_label=False, columns=4, height="auto")
batch_context = gr.Textbox(label="context_json", value="{}", lines=4)
batch_max = gr.Slider(16, 1024, value=256, step=16, label="max_new_tokens")
batch_temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature")
batch_btn = gr.Button("Describe batch")
batch_out = gr.JSON(label="Descriptions (list)")
# Note: Gradio Gallery returns paths/objects; the client is used to load files
batch_btn.click(
describe_batch,
[batch_in_images, batch_context, batch_max, batch_temp],
batch_out,
api_name="predict",
concurrency_limit=1
)
# Facial embedding section
with gr.Row():
face_img = gr.Image(label="Image for facial embedding", type="pil")
face_btn = gr.Button("Get facial embedding")
face_out = gr.JSON(label="Facial embedding (vector)")
face_btn.click(face_image_embedding, [face_img], face_out, api_name="face_image_embedding", concurrency_limit=1)
# Video scene extraction section
with gr.Row():
video_file = gr.Video(label="Upload a video")
threshold = gr.Slider(0.0, 100.0, value=30.0, step=1.0, label="Threshold")
offset_frames = gr.Slider(0, 30, value=5, step=1, label="Offset frames")
crop_ratio = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Crop ratio")
scenes_btn = gr.Button("Extract scenes")
scenes_gallery_out = gr.Gallery(label="Scene keyframes", show_label=False, columns=4, height="auto")
scenes_info_out = gr.JSON(label="Scene information")
# Bind the scene extraction function
scenes_btn.click(
scenes_extraction,
inputs=[video_file, threshold, offset_frames, crop_ratio],
outputs=[scenes_gallery_out, scenes_info_out],
api_name="scenes_extraction",
concurrency_limit=1
)
# List image description with Salamandra Vision
with gr.Row():
img_input = gr.Gallery(label="List images", show_label=False, columns=4, height="auto")
describe_btn = gr.Button("Generate descriptions")
desc_output = gr.Textbox(label="Image descriptions", lines=10)
    def _describe_gallery(items):
        # Gallery entries may be PIL images, file paths, or (item, caption) tuples depending on the Gradio version
        pil_images = []
        for item in items or []:
            item = item[0] if isinstance(item, (tuple, list)) else item
            pil_images.append(item if isinstance(item, Image.Image) else Image.open(item))
        return "\n".join(describe_list_images(pil_images)) if pil_images else "No images uploaded."
    describe_btn.click(
        _describe_gallery,
        inputs=[img_input],
        outputs=desc_output,
        api_name="describe_images",
        concurrency_limit=1
    )
demo.queue(max_size=16).launch(show_error=True)
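# Illustrative sketch: calling the endpoints registered above programmatically with
# `gradio_client`. The Space identifier "user/svision" and the file names are placeholders;
# the other endpoints ("predict", "scenes_extraction", "describe_images") can be called the
# same way with their respective inputs.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("user/svision")
#   # Single-image description (api_name="describe"): image, prompt, max_new_tokens, temperature
#   text = client.predict(handle_file("photo.jpg"), "Describe the image in detail.", 256, 0.7,
#                         api_name="/describe")
#   # Facial embedding (api_name="face_image_embedding"): returns the 512-float vector or None
#   embedding = client.predict(handle_file("face.jpg"), api_name="/face_image_embedding")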