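"""Intent-routing microservice.

A minimal Flask app that maps an incoming prompt to one of a fixed set of
routing tokens, using a multilingual zero-shot classification model served
through the Hugging Face Inference API, plus a fast-path for explicit
"/image" commands.
"""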
import os
import sys
from flask import Flask, request, jsonify
from huggingface_hub import InferenceClient

app = Flask(__name__)

API_KEY = (os.getenv("API_KEY") or "").strip()
# Multilingual zero-shot model (handles Hindi + English well)
ZSL_MODEL_ID = os.getenv("ZSL_MODEL_ID", "joeddav/xlm-roberta-large-xnli").strip()

LABELS = [
    "health_wellness",
    "spiritual_guidance",
    "generate_image",
    "realtime_query",
    "other_query",
]
ALLOWED = set(LABELS)
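# /generate_text always replies with exactly one of these tokens; any label the
# classifier returns outside this set is coerced to "other_query".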

def log(msg, **kv):
    line = " | ".join([msg] + [f"{k}={v}" for k, v in kv.items()])
    print(line, file=sys.stderr, flush=True)


# Init HF client once
client = InferenceClient(token=API_KEY) if API_KEY else None

@app.get("/")
def root():
    return jsonify({"ok": True, "model": ZSL_MODEL_ID})

@app.post("/generate_text")
def generate_text():
    if not API_KEY:
        log("DECISION_ERR", reason="missing_api_key")
        return jsonify({"error": "Missing API_KEY"}), 400
    if client is None:
        log("DECISION_ERR", reason="client_not_initialized")
        return jsonify({"error": "Client not initialized"}), 500

    data = request.get_json(silent=True) or {}
    prompt = (data.get("prompt") or "").strip()
    instructions = (data.get("instructions") or "").strip()  # not required here

    if not prompt:
        log("DECISION_BAD_REQ", has_prompt=False)
        return jsonify({"error": "Missing required fields"}), 400

    # Fast-path: explicit image command
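    # e.g. a prompt of "/image a red fox in the snow" returns the "generate_image"
    # token immediately, without calling the classifier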
    if prompt.startswith("/image "):
        log("DECISION_FAST", token="generate_image")
        return jsonify({"response": "generate_image"}), 200

    try:
        log("DECISION_CALL_ZSL", model=ZSL_MODEL_ID, prompt_len=len(prompt))
        zs = client.zero_shot_classification(
            prompt,
            LABELS,
            model=ZSL_MODEL_ID,
            hypothesis_template="This text is about {}.",
            multi_label=False,  # single best label
        )
    
        # Normalize shapes:
        # - Newer hub often returns a dict
        # - Some providers return a list[dict] (one per input)
        if isinstance(zs, list):
            zs = zs[0] if zs else {}
    
        if not isinstance(zs, dict):
            raise ValueError(f"Unexpected ZSL response type: {type(zs)}")
    
        labels = zs.get("labels") or zs.get("candidate_labels") or []
        scores = zs.get("scores") or []
        if not labels and "label" in zs:
            labels = [zs["label"]]
            scores = [zs.get("score", 0.0)]
    
        best = labels[0] if labels else "other_query"
        score = float(scores[0]) if scores else 0.0
        token = best if best in ALLOWED else "other_query"
    
        log("DECISION_OK", token=token, top_label=best, score=round(score, 4))
        return jsonify({"response": token}), 200
    
    except Exception as e:
        log("DECISION_FAIL", error=str(e))
        return jsonify({"response": "other_query", "error": str(e)}), 200


if __name__ == "__main__":
    port = int(os.getenv("PORT", 7860))
    log("BOOT", port=port, zsl_model=ZSL_MODEL_ID, api_key_set=bool(API_KEY))
    app.run(host="0.0.0.0", port=port)
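

# --- Example client call (illustrative sketch) ---
# Assumes the server is running locally on the default port 7860 and that a
# valid Hugging Face token is set in the API_KEY environment variable.
#
#   import requests
#   r = requests.post(
#       "http://localhost:7860/generate_text",
#       json={"prompt": "Mujhe sardi ho gayi hai, kya karun?"},  # "I have a cold, what should I do?"
#   )
#   print(r.json())  # e.g. {"response": "health_wellness"}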