Spaces:
Running
Running
File size: 3,130 Bytes
1314214 37becb1 f9faf91 8e98672 1314214 369b62f 37becb1 369b62f 37becb1 c696661 37becb1 369b62f d2b7fba 16db54b 369b62f 16db54b d2b7fba 1314214 37becb1 369b62f 5c2c28b 37becb1 369b62f 1314214 369b62f 37becb1 369b62f c696661 369b62f c696661 369b62f c696661 369b62f c696661 369b62f c696661 369b62f 1314214 c696661 1314214 37becb1 369b62f 37becb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import os
import sys
from flask import Flask, request, jsonify
from huggingface_hub import InferenceClient
# Flask application instance used by the route decorators below.
app = Flask(__name__)

# Hugging Face API token, read once at import time; "" when unset.
API_KEY = (os.getenv("API_KEY") or "").strip()

# Multilingual zero-shot model (handles Hindi + English well)
ZSL_MODEL_ID = os.getenv("ZSL_MODEL_ID", "joeddav/xlm-roberta-large-xnli").strip()

# Candidate routing tokens offered to the zero-shot classifier. Order is only
# the order sent to the model, not a priority ranking.
LABELS = [
    "health_wellness",
    "spiritual_guidance",
    "generate_image",
    "realtime_query",
    "other_query",
]

# Set form of LABELS for O(1) validation of the model's top label.
ALLOWED = set(LABELS)
def log(msg, **kv):
    """Write one pipe-delimited diagnostic line to stderr, flushed immediately."""
    fields = [msg]
    for key, value in kv.items():
        fields.append(f"{key}={value}")
    print(" | ".join(fields), file=sys.stderr, flush=True)
# Build the shared HF inference client exactly once at import time; without a
# key the client stays None and the route reports the misconfiguration.
client = None if not API_KEY else InferenceClient(token=API_KEY)
@app.get("/")
def root():
    """Health check: confirm the service is up and report the configured model."""
    payload = {"ok": True, "model": ZSL_MODEL_ID}
    return jsonify(payload)
def _best_zsl_token(zs):
    """Normalize a zero-shot-classification response and pick the routing token.

    Providers return several shapes: a dict, a list[dict] (one per input), or a
    flat {"label": ..., "score": ...}. Returns (token, top_label, score) where
    token is guaranteed to be in ALLOWED ("other_query" fallback).

    Raises ValueError on a response shape that cannot be normalized.
    """
    # - Newer hub often returns a dict
    # - Some providers return a list[dict] (one per input)
    if isinstance(zs, list):
        zs = zs[0] if zs else {}
    if not isinstance(zs, dict):
        raise ValueError(f"Unexpected ZSL response type: {type(zs)}")
    labels = zs.get("labels") or zs.get("candidate_labels") or []
    scores = zs.get("scores") or []
    # Flat single-result shape: {"label": ..., "score": ...}
    if not labels and "label" in zs:
        labels = [zs["label"]]
        scores = [zs.get("score", 0.0)]
    best = labels[0] if labels else "other_query"
    score = float(scores[0]) if scores else 0.0
    token = best if best in ALLOWED else "other_query"
    return token, best, score


@app.post("/generate_text")
def generate_text():
    """Classify the incoming prompt into a single routing token.

    Expects JSON {"prompt": str, "instructions": str (optional)} and returns
    {"response": <token>} with <token> drawn from LABELS. Classification
    failures degrade to "other_query" with HTTP 200 so callers always receive
    a routable token; only missing config/prompt produce error statuses.
    """
    if not API_KEY:
        log("DECISION_ERR", reason="missing_api_key")
        return jsonify({"error": "Missing API_KEY"}), 400
    if client is None:
        # Defensive: currently unreachable (client is None only when API_KEY
        # is empty), kept in case client init ever fails independently.
        log("DECISION_ERR", reason="client_not_initialized")
        return jsonify({"error": "Client not initialized"}), 500

    data = request.get_json(silent=True) or {}
    prompt = (data.get("prompt") or "").strip()
    instructions = (data.get("instructions") or "").strip()  # not required here
    if not prompt:
        log("DECISION_BAD_REQ", has_prompt=False)
        return jsonify({"error": "Missing required fields"}), 400

    # Fast-path: explicit image command skips the model call entirely
    if prompt.startswith("/image "):
        log("DECISION_FAST", token="generate_image")
        return jsonify({"response": "generate_image"}), 200

    try:
        log("DECISION_CALL_ZSL", model=ZSL_MODEL_ID, prompt_len=len(prompt))
        zs = client.zero_shot_classification(
            prompt,
            LABELS,
            model=ZSL_MODEL_ID,
            hypothesis_template="This text is about {}.",
            multi_label=False,  # single best label
        )
        token, best, score = _best_zsl_token(zs)
        log("DECISION_OK", token=token, top_label=best, score=round(score, 4))
        return jsonify({"response": token}), 200
    except Exception as e:
        # Best-effort fallback: never fail the route; route to other_query.
        log("DECISION_FAIL", error=str(e))
        return jsonify({"response": "other_query", "error": str(e)}), 200
if __name__ == "__main__":
    # Spaces/containers inject PORT; 7860 is the HF Spaces default.
    serve_port = int(os.environ.get("PORT", 7860))
    log("BOOT", port=serve_port, zsl_model=ZSL_MODEL_ID, api_key_set=bool(API_KEY))
    app.run(host="0.0.0.0", port=serve_port)
|