|
import gradio as gr |
|
import os |
|
import logging |
|
import json |
|
from huggingface_hub import InferenceClient |
|
from huggingface_hub.utils import HfHubHTTPError |
|
import traceback |
|
|
|
|
|
# ---- Logging -------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ---- Target-API credentials ----------------------------------------------
# The Space secret HF_API_TOKEN authorizes the proxy's calls to the hosted
# inference API. Without it the app still starts, but only serves errors.
HF_TARGET_TOKEN = os.environ.get("HF_API_TOKEN")

if not HF_TARGET_TOKEN:
    logger.error("CRITICAL: HF_API_TOKEN secret not found in Space environment variables!")

# target_client stays None (and initialization_error is set) whenever the
# proxy cannot be wired up; proxy_inference checks both before serving.
target_client = None
initialization_error = None

if HF_TARGET_TOKEN:
    try:
        target_client = InferenceClient(token=HF_TARGET_TOKEN)
        logger.info("Target InferenceClient initialized.")
    except Exception as exc:
        initialization_error = f"Failed to initialize target InferenceClient: {exc}"
        logger.error(initialization_error)
        target_client = None
else:
    initialization_error = "Service Unavailable: Proxy configuration error (Missing Token)."
    logger.error(initialization_error)
|
|
|
|
|
def proxy_inference(request_data: dict) -> dict:
    """
    Gradio function to handle inference requests.

    Expects a dictionary (from gr.JSON input) like:
        {
            "imageDataUrl": "data:image/...",
            "candidate_labels": ["label1", "label2", ...]
        }

    Returns a dictionary (for gr.JSON output) like:
        {"result": [...]} on success, or
        {"error": "...", "details": "..."} on failure.
    """
    logger.info(f"Received request data via Gradio function: {request_data}")

    # --- Startup / configuration failures -----------------------------------
    if initialization_error:
        logger.error(f"Returning initialization error: {initialization_error}")
        # Keep the legacy "setup_error" key for existing callers, but also
        # expose the documented {"error", "details"} shape (the original
        # returned only "setup_error", contradicting the docstring).
        return {
            "error": "Service Unavailable",
            "details": initialization_error,
            "setup_error": initialization_error,
        }
    if not target_client:
        logger.error("Target client not available.")
        return {"error": "Configuration Error", "details": "Target client not initialized."}

    # --- Input validation ----------------------------------------------------
    if not isinstance(request_data, dict):
        logger.error(f"Invalid input type: expected dict, got {type(request_data)}")
        return {"error": "Bad Request", "details": "Input must be a JSON object."}

    image_data_url = request_data.get("imageDataUrl")
    # Fallback labels keep ad-hoc callers working when none are supplied.
    candidate_labels = request_data.get("candidate_labels", ["person", "car", "building", "animal", "tree"])

    if not image_data_url or not isinstance(image_data_url, str) or not image_data_url.startswith('data:image'):
        logger.error("Missing or invalid 'imageDataUrl' in request.")
        return {"error": "Bad Request", "details": "Missing or invalid 'imageDataUrl'."}
    if not isinstance(candidate_labels, list):
        logger.error("Invalid 'candidate_labels', must be a list.")
        return {"error": "Bad Request", "details": "'candidate_labels' must be a list."}
    if not candidate_labels:
        # An empty list would be forwarded and fail opaquely on the target API;
        # reject it here with a clear message instead.
        logger.error("Empty 'candidate_labels' list.")
        return {"error": "Bad Request", "details": "'candidate_labels' must not be empty."}

    logger.info(f"Image URL prefix: {image_data_url[:70]}...")
    logger.info(f"Labels: {candidate_labels}")

    # --- Forward to the target Inference API --------------------------------
    try:
        logger.info("Calling target_client.zero_shot_image_classification...")
        inference_output = target_client.zero_shot_image_classification(
            image=image_data_url,
            candidate_labels=candidate_labels
        )
        logger.info("Successfully received response from target API.")
        return {"result": inference_output}

    except HfHubHTTPError as e:
        # Surface the target API's status and request id so callers can debug.
        status_code = e.response.status_code if hasattr(e, 'response') else 500
        # getattr: request_id is not guaranteed on every huggingface_hub
        # version; the original `e.request_id` could raise AttributeError here.
        request_id = getattr(e, 'request_id', None)
        error_detail = str(e)

        if hasattr(e, 'response'):
            try:
                error_data = e.response.json()
                # Guard against non-dict JSON bodies (e.g. a bare string/list),
                # which would make .get() raise AttributeError.
                if isinstance(error_data, dict):
                    error_detail = error_data.get("error", str(e))
            except ValueError:
                # Body was not JSON; keep the generic exception text.
                # (Narrowed from the original bare `except: pass`.)
                pass

        logger.error(f"HTTP Error from target HF API: Status={status_code}, RequestID={request_id}, Error={error_detail}")
        return {
            "error": f"Target API Error (Status {status_code})",
            "details": error_detail,
            "request_id": request_id
        }
    except Exception as e:
        # Catch-all boundary: this function feeds a JSON endpoint, so every
        # failure must come back as a structured payload, never a traceback.
        error_detail = str(e)
        logger.error(f"Unexpected error in proxy function: {error_detail}\n{traceback.format_exc()}")
        return {
            "error": "Internal Server Error in Proxy",
            "details": error_detail
        }
|
|
|
|
|
|
|
|
|
# Sample payloads displayed in the UI so callers can see the expected
# request and response shapes for the /predict endpoint.
input_example = dict(
    imageDataUrl="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAUA...",
    candidate_labels=["cat", "dog", "car"],
)

# Scores below are illustrative only; the real model's distribution differs.
output_example_success = dict(
    result=[
        dict(score=0.95, label="cat"),
        dict(score=0.03, label="dog"),
        dict(score=0.02, label="car"),
    ],
)

output_example_error = dict(
    error="Target API Error (Status 422)",
    details="Input validation error on target server.",
    request_id="abc-123",
)
|
|
|
|
|
# ---- Gradio UI -----------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Inference Proxy\nAccepts JSON input with `imageDataUrl` and `candidate_labels`, calls the target zero-shot model, and returns JSON output.")

    with gr.Row():
        input_json = gr.JSON(label="Input Data (JSON)", value=input_example)
        output_json = gr.JSON(label="Output Result (JSON)")

    # Render both example panels from one template to keep them in sync.
    for caption, payload in (
        ("Example Success Output", output_example_success),
        ("Example Error Output", output_example_error),
    ):
        gr.Markdown(f"**{caption}:**\n```json\n{json.dumps(payload, indent=2)}\n```")

    # Hidden trigger: external clients invoke the proxy through the named
    # API endpoint (/predict) rather than by clicking this button.
    submit_btn = gr.Button("Process (for API)", visible=False)
    submit_btn.click(
        fn=proxy_inference,
        inputs=input_json,
        outputs=output_json,
        api_name="predict",
    )

demo.launch(share=False)