Spaces:
Running
Running
# services/huggingface_service.py | |
from .vlm_service import VLMService, ModelType | |
from typing import Dict, Any | |
import aiohttp | |
import base64 | |
import json | |
import time | |
import re | |
import imghdr | |
class HuggingFaceService(VLMService):
    """
    Hugging Face Inference Providers (OpenAI-compatible) service.

    This class speaks to https://router.huggingface.co/v1/chat/completions
    so you can call many VLMs with the same payload shape.

    Any HTTP, network, timeout or response-decoding failure is surfaced as a
    plain ``Exception`` whose message starts with ``MODEL_UNAVAILABLE:`` so
    callers can detect it and fall back to another model.

    NOTE(review): this class no longer uses ``imghdr`` (deprecated in Python
    3.11, removed in 3.13); the module-level ``import imghdr`` can be dropped.
    """

    # Canonical metadata values; anything else is replaced with "OTHER".
    _ALLOWED_EPSG = frozenset({"4326", "3857", "32617", "32633", "32634", "OTHER"})
    _ALLOWED_SOURCES = frozenset({"PDC", "GDACS", "WFP", "GFH", "GGC", "USGS", "OTHER"})
    _ALLOWED_TYPES = frozenset({
        "BIOLOGICAL_EMERGENCY", "CHEMICAL_EMERGENCY", "CIVIL_UNREST",
        "COLD_WAVE", "COMPLEX_EMERGENCY", "CYCLONE", "DROUGHT", "EARTHQUAKE",
        "EPIDEMIC", "FIRE", "FLOOD", "FLOOD_INSECURITY", "HEAT_WAVE",
        "INSECT_INFESTATION", "LANDSLIDE", "OTHER", "PLUVIAL",
        "POPULATION_MOVEMENT", "RADIOLOGICAL_EMERGENCY", "STORM",
        "TRANSPORTATION_EMERGENCY", "TSUNAMI", "VOLCANIC_ERUPTION",
    })

    # Opening fence: "```json" (any case) or a bare "```" followed by
    # whitespace. Fences tagged with another language are left untouched.
    _FENCE_OPEN_RE = re.compile(r"^```(?:json|(?=\s))\s*", re.IGNORECASE)
    _FENCE_CLOSE_RE = re.compile(r"\s*```$")

    # (metadata type, pattern) pairs tried in order by _extract_glm_metadata;
    # the first match wins. "VOLCANIC" alone is not an allowed type, so it is
    # mapped to the canonical VOLCANIC_ERUPTION.
    _GLM_TYPE_PATTERNS = (
        ("EARTHQUAKE", r"\bEARTHQUAKE\b"),
        ("FLOOD", r"\bFLOOD\b"),
        ("CYCLONE", r"\bCYCLONE\b"),
        ("DROUGHT", r"\bDROUGHT\b"),
        ("FIRE", r"\bFIRE\b"),
        ("STORM", r"\bSTORM\b"),
        ("TSUNAMI", r"\bTSUNAMI\b"),
        ("VOLCANIC_ERUPTION", r"\bVOLCANIC(?:[_ ]ERUPTION)?\b"),
    )
    # Upper-case two-letter tokens excluded from country detection as common
    # false positives.
    _GLM_COUNTRY_FALSE_POSITIVES = frozenset({
        "SO", "IS", "OR", "IN", "ON", "TO", "OF", "AT", "BY", "NO", "GO", "UP", "US",
    })

    def __init__(self, api_key: str, model_id: str = "Qwen/Qwen2.5-VL-7B-Instruct"):
        super().__init__(f"HF_{model_id.replace('/', '_')}", ModelType.CUSTOM)
        self.api_key = api_key
        self.model_id = model_id
        self.providers_url = "https://router.huggingface.co/v1/chat/completions"

    def _guess_mime(self, image_bytes: bytes) -> str:
        """Return the MIME type of *image_bytes*, sniffed from its magic bytes.

        Recognises PNG, JPEG, WebP and GIF; anything else (including empty
        input) falls back to ``image/jpeg``.
        """
        header = bytes(image_bytes[:12])
        if header.startswith(b"\x89PNG\r\n\x1a\n"):
            return "image/png"
        if header.startswith(b"\xff\xd8\xff"):
            return "image/jpeg"
        if header[:4] == b"RIFF" and header[8:12] == b"WEBP":
            return "image/webp"
        if header[:6] in (b"GIF87a", b"GIF89a"):
            return "image/gif"
        return "image/jpeg"

    async def generate_caption(
        self,
        image_bytes: bytes,
        prompt: str,
        metadata_instructions: str = "",
    ) -> Dict[str, Any]:
        """
        Generate caption using HF Inference Providers (OpenAI-style).

        Args:
            image_bytes: Raw image file contents, sent inline as a base64 data URL.
            prompt: Main instruction text for the model.
            metadata_instructions: Optional extra instructions appended after
                the prompt (separated by a blank line).

        Returns:
            A dict with ``caption`` (description, analysis and recommended
            actions combined into one string), ``metadata`` (sanitized dict),
            ``confidence`` (always None), ``processing_time`` (seconds),
            ``raw_response`` (model id, decoded JSON response, and whether any
            metadata was found), plus the individual ``description``,
            ``analysis`` and ``recommended_actions`` values.

        Raises:
            Exception: message prefixed with ``MODEL_UNAVAILABLE:`` on any
                HTTP, network, timeout or response-decoding failure.
        """
        start_time = time.perf_counter()

        payload = self._build_chat_payload(image_bytes, prompt, metadata_instructions)
        result = await self._post_chat_completion(payload)

        cleaned = self._strip_json_fence(self._message_text(result).strip())
        description, analysis, recommended_actions, metadata = self._parse_model_output(cleaned)

        # Capture this before sanitizing: _sanitize_metadata always fills in
        # defaults, so the dict is never empty afterwards.
        parsed_successfully = bool(metadata)
        metadata = self._sanitize_metadata(metadata)

        # Combine all three parts for backward compatibility
        caption_text = (
            f"Description: {description}\n\n"
            f"Analysis: {analysis}\n\n"
            f"Recommended Actions: {recommended_actions}"
        )

        return {
            "caption": caption_text,
            "metadata": metadata,
            "confidence": None,
            "processing_time": time.perf_counter() - start_time,
            "raw_response": {
                "model": self.model_id,
                "response": result,
                "parsed_successfully": parsed_successfully,
            },
            "description": description,
            "analysis": analysis,
            "recommended_actions": recommended_actions,
        }

    def _build_chat_payload(
        self, image_bytes: bytes, prompt: str, metadata_instructions: str
    ) -> Dict[str, Any]:
        """Build an OpenAI-compatible chat payload with one text + one image block."""
        # Join only the non-empty parts so an empty prompt does not leave a
        # leading blank line in the instruction.
        instruction = "\n\n".join(
            part
            for part in ((prompt or "").strip(), (metadata_instructions or "").strip())
            if part
        )
        mime = self._guess_mime(image_bytes)
        encoded = base64.b64encode(image_bytes).decode("ascii")
        return {
            "model": self.model_id,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": instruction},
                        {"type": "image_url", "image_url": {"url": f"data:{mime};base64,{encoded}"}},
                    ],
                }
            ],
            "max_tokens": 512,
            "temperature": 0.2,
        }

    async def _post_chat_completion(self, payload: Dict[str, Any]) -> Any:
        """POST *payload* to the Providers router and return the decoded JSON body.

        Raises:
            Exception: ``MODEL_UNAVAILABLE:`` message on a non-200 status, on
                any transport error/timeout, or when the body is not valid JSON.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    self.providers_url,
                    headers=headers,
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=180),
                ) as resp:
                    status = resp.status
                    raw_text = await resp.text()
        except Exception as e:
            # Network, timeout, etc. -- treat as model unavailable, keep the cause.
            raise Exception(f"MODEL_UNAVAILABLE: {self.model_name} is currently unavailable due to an error. Switching to another model.") from e

        if status != 200:
            # Any non-200 status - throw generic error for fallback handling
            raise Exception(f"MODEL_UNAVAILABLE: {self.model_name} is currently unavailable (HTTP {status}). Switching to another model.")

        # Decode from the text already read instead of resp.json(), which would
        # re-read the body and reject valid JSON served with another content-type.
        try:
            return json.loads(raw_text)
        except ValueError as e:
            raise Exception(f"MODEL_UNAVAILABLE: {self.model_name} is currently unavailable due to an error. Switching to another model.") from e

    @staticmethod
    def _message_text(result: Any) -> str:
        """Return the text of the first choice's message, or "" if absent.

        Falls back to ``reasoning_content`` and flattens list-of-blocks
        content into newline-joined text.
        """
        choices = result.get("choices") if isinstance(result, dict) else None
        first = choices[0] if isinstance(choices, list) and choices else None
        message = first.get("message") if isinstance(first, dict) else None
        if not isinstance(message, dict):
            return ""

        # GLM models sometimes put content in reasoning_content field
        content = message.get("content") or message.get("reasoning_content") or ""
        if isinstance(content, list):
            # Some providers may return a list of output blocks
            # (e.g., {"type": "output_text", "text": ...})
            parts = []
            for block in content:
                text = (block.get("text") or block.get("content")) if isinstance(block, dict) else block
                if text:
                    parts.append(str(text))
            content = "\n".join(parts)
        return content if isinstance(content, str) else str(content)

    @classmethod
    def _strip_json_fence(cls, text: str) -> str:
        """Remove an accidental Markdown code fence wrapping JSON output.

        The closing fence is only removed when an opening fence was.
        """
        stripped, count = cls._FENCE_OPEN_RE.subn("", text, count=1)
        if not count:
            return text
        return cls._FENCE_CLOSE_RE.sub("", stripped)

    def _parse_model_output(self, cleaned: str) -> tuple[Any, Any, Any, dict]:
        """Split model output into (description, analysis, recommended_actions, metadata).

        Preferred protocol is a JSON object with those keys. Otherwise text
        containing ``<think>`` is mined with _extract_glm_metadata, and any
        other text becomes the analysis with empty metadata.
        """
        try:
            parsed = json.loads(cleaned)
        except json.JSONDecodeError:
            parsed = None

        if isinstance(parsed, dict):
            description = parsed.get("description")
            analysis = parsed.get("analysis")
            recommended_actions = parsed.get("recommended_actions")
            metadata = parsed.get("metadata")
            return (
                "" if description is None else description,
                cleaned if analysis is None else analysis,
                "" if recommended_actions is None else recommended_actions,
                # Non-dict metadata (list, string, ...) cannot be validated.
                metadata if isinstance(metadata, dict) else {},
            )

        # Not JSON, or valid JSON that is not an object (list/number/string):
        # handle as free text instead of crashing on .get().
        if "<think>" in cleaned:
            analysis, metadata = self._extract_glm_metadata(cleaned)
            return "", analysis, "", metadata
        return "", cleaned, "", {}

    @classmethod
    def _sanitize_metadata(cls, metadata: dict) -> dict:
        """Coerce metadata fields to canonical values, in place, and return the dict.

        ``epsg``, ``source`` and ``type`` become "OTHER" when missing or not in
        their allowed set; recognised values are stored as upper-case strings.
        ``countries`` is forced to a list.
        """
        for key, allowed in (
            ("epsg", cls._ALLOWED_EPSG),
            ("source", cls._ALLOWED_SOURCES),
            ("type", cls._ALLOWED_TYPES),
        ):
            value = metadata.get(key)
            canonical = str(value).upper()
            metadata[key] = canonical if value and canonical in allowed else "OTHER"

        # Ensure countries is always a list
        countries = metadata.get("countries")
        if not countries or not isinstance(countries, list):
            metadata["countries"] = []
        return metadata

    def _extract_glm_metadata(self, content: str) -> tuple[str, dict]:
        """
        Extract metadata from GLM thinking format using simple, robust patterns.

        Focus on extracting what we can and rely on defaults for the rest.

        Returns:
            (caption_text, metadata): the text before the first line that starts
            the metadata discussion (or the whole tag-stripped text if that is
            empty), and whatever of title/source/type/countries/epsg was found.
        """
        # Remove <think> tags (the text between them is kept)
        content = re.sub(r"<think>|</think>", "", content)
        metadata: dict = {}

        # Title: first 5-50 char quoted string after "Maybe" or "Title"
        title_match = re.search(r'(?:Maybe|Title).*?["\']([^"\']{5,50})["\']', content, re.IGNORECASE)
        if title_match:
            metadata["title"] = title_match.group(1).strip()

        # Source: common source names (WFP, PDC, etc.)
        source_match = re.search(r"\b(WFP|PDC|GDACS|GFH|GGC|USGS)\b", content, re.IGNORECASE)
        if source_match:
            metadata["source"] = source_match.group(1).upper()

        # Type: first matching disaster keyword
        for event_type, pattern in self._GLM_TYPE_PATTERNS:
            if re.search(pattern, content, re.IGNORECASE):
                metadata["type"] = event_type
                break

        # Countries: upper-case 2-letter tokens minus common false positives.
        # dict.fromkeys de-duplicates while keeping first-seen order; set()
        # order varies between runs because of string hash randomisation.
        countries = [
            code
            for code in re.findall(r"\b([A-Z]{2})\b", content)
            if code not in self._GLM_COUNTRY_FALSE_POSITIVES
        ]
        if countries:
            metadata["countries"] = list(dict.fromkeys(countries))

        # EPSG: only the explicitly supported codes
        epsg_match = re.search(r"\b(4326|3857|32617|32633|32634)\b", content)
        if epsg_match:
            metadata["epsg"] = epsg_match.group(1)

        # For caption, just use the first part before metadata discussion
        markers = ("metadata:", "now for the metadata", "let me double-check")
        caption_lines = []
        for line in content.split("\n"):
            lowered = line.lower()
            if any(marker in lowered for marker in markers):
                break
            caption_lines.append(line)
        caption_text = "\n".join(caption_lines).strip()
        return caption_text or content, metadata
# --- Generic Model Wrapper for Dynamic Registration --- | |
class ProvidersGenericVLMService(HuggingFaceService):
    """
    Generic wrapper so you can register ANY Providers VLM by model_id from config.

    Example:
        ProvidersGenericVLMService(HF_TOKEN, "Qwen/Qwen2.5-VL-32B-Instruct", "QWEN2_5_VL_32B")

    Args:
        api_key: API token, forwarded unchanged to HuggingFaceService.
        model_id: Model id to call, e.g. "Qwen/Qwen2.5-VL-32B-Instruct".
        public_name: Name stored in ``model_name``. When omitted or empty,
            it defaults to ``model_id`` with "/" replaced by "_" and upper-cased.
    """

    def __init__(self, api_key: str, model_id: str, public_name: str | None = None):
        super().__init__(api_key, model_id)
        # Use a human-friendly stable name that your UI/DB will reference
        self.model_name = public_name or model_id.replace("/", "_").upper()
        self.model_type = ModelType.CUSTOM