Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| from contextlib import asynccontextmanager | |
| from typing import List, Optional | |
| import torch | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from transformers import AutoProcessor, Gemma3ForConditionalGeneration | |
# =========================
# Config
# =========================
# Runtime settings, each overridable through an environment variable.
MODEL_ID = os.getenv("MODEL_ID", "google/gemma-3-4b-it")
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "12"))
HF_TOKEN = os.getenv("HF_TOKEN")

# To change the intent labels without editing the code, set INTENTS to a
# comma-separated list, for example:
#   INTENTS="greeting,pricing,complaint,booking,follow_up,other"
_DEFAULT_INTENTS = "same_path,change_path,greeting,pricing,booking,complaint,follow_up,other"
INTENTS_ENV = os.getenv("INTENTS", _DEFAULT_INTENTS)
# Whitespace around each label is trimmed and empty entries are discarded.
ALLOWED_INTENTS = list(filter(None, map(str.strip, INTENTS_ENV.split(","))))

# Populated by the app's startup hook; None until the model has been loaded.
model = processor = None
# =========================
# Schemas
# =========================
class IntentRequest(BaseModel):
    """JSON body accepted by the classification endpoint.

    Fields:
        message: the text to classify.
        intents: optional per-request label set; when missing or empty the
            configured default labels are used instead.
        system_prompt: optional text that replaces the built-in system prompt.
    """

    message: str
    intents: Optional[List[str]] = None
    system_prompt: Optional[str] = None
class IntentResponse(BaseModel):
    """JSON body returned by the classification endpoint.

    Fields:
        intent: the normalized label chosen for the message.
        raw_output: the decoded model text the label was derived from.
        model: identifier of the model that produced the answer.
    """

    intent: str
    raw_output: str
    model: str
# =========================
# Helpers
# =========================
def normalize_intent(text: str, allowed_intents: List[str], fallback: str = "other") -> str:
    """Map raw model output onto one of ``allowed_intents``.

    Matching is case-insensitive and the label is returned exactly as it
    appears in ``allowed_intents``. Resolution order:

    1. the whole answer (minus surrounding punctuation/quotes) equals a label;
    2. otherwise the first label, in list order, that appears in the answer
       as a standalone token (not embedded in a longer word);
    3. otherwise ``fallback``.

    Args:
        text: decoded model output.
        allowed_intents: candidate labels; blank labels are ignored.
        fallback: value returned when nothing matches (default ``"other"``,
            which is returned even if it is not in ``allowed_intents``).

    Returns:
        The matched label, or ``fallback``.
    """
    # Lowercase and remove any markdown code fences / inline backticks.
    cleaned = text.strip().lower().replace("`", "").strip()

    # Blank labels would match at every position, so drop them up front.
    candidates = [(intent, intent.strip().lower()) for intent in allowed_intents if intent.strip()]

    # 1) The model answered with exactly one label (possibly with trailing
    #    punctuation or quotes) -- prefer that over any substring hit.
    answer = cleaned.strip(" \t\r\n.,;:!?\"'")
    for intent, key in candidates:
        if answer == key:
            return intent

    # 2) The model answered with a sentence that mentions a label. Use
    #    lookarounds rather than \b so labels that start/end with non-word
    #    characters (e.g. "c++") still match as whole tokens.
    for intent, key in candidates:
        if re.search(rf"(?<!\w){re.escape(key)}(?!\w)", cleaned):
            return intent

    return fallback
def build_prompt(user_message: str, allowed_intents: List[str], custom_system_prompt: Optional[str]) -> List[dict]:
    """Assemble the two-turn chat (system + user) sent to the model.

    Args:
        user_message: text to classify, used verbatim as the user turn.
        allowed_intents: labels listed in the default system prompt.
        custom_system_prompt: if truthy, used as the system turn instead of
            the default instructions.

    Returns:
        A list of message dicts whose ``content`` is a single text part.
    """
    labels = ", ".join(allowed_intents)

    if custom_system_prompt:
        system_text = custom_system_prompt
    else:
        system_text = (
            "You are an intent classifier.\n"
            f"Choose exactly one intent from this list: {labels}.\n"
            "Return only the intent label, with no explanation, no punctuation, and no extra words."
        )

    def _text_turn(role: str, text: str) -> dict:
        # One chat turn holding a single text content part.
        return {"role": role, "content": [{"type": "text", "text": text}]}

    return [_text_turn("system", system_text), _text_turn("user", user_message)]
def run_intent_classification(user_message: str, allowed_intents: List[str], custom_system_prompt: Optional[str]) -> tuple[str, str]:
    """Run one greedy generation with the loaded model and normalize the result.

    Args:
        user_message: text to classify.
        allowed_intents: candidate labels passed to the prompt and matcher.
        custom_system_prompt: optional replacement system prompt.

    Returns:
        ``(intent, raw_output)``: the normalized label and the decoded text
        generated after the prompt.

    Raises:
        RuntimeError: if the module-level ``model``/``processor`` are still
            None (i.e. the startup hook has not loaded them).
    """
    # The globals are only read here, so no `global` statement is needed.
    if model is None or processor is None:
        raise RuntimeError("Model is not loaded yet.")

    messages = build_prompt(user_message, allowed_intents, custom_system_prompt)
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)  # no-op on CPU; keeps inputs with the weights otherwise

    with torch.inference_mode():
        generation = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            # Greedy decoding; temperature/top_p are cleared explicitly,
            # presumably to override sampling values from the generation config.
            do_sample=False,
            temperature=None,
            top_p=None,
        )

    # generate() returns prompt + continuation; keep only the new tokens.
    prompt_len = inputs["input_ids"].shape[-1]
    new_tokens = generation[0][prompt_len:]
    decoded = processor.decode(new_tokens, skip_special_tokens=True).strip()

    return normalize_intent(decoded, allowed_intents), decoded
# =========================
# Lifespan
# =========================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the processor and model before serving; log on shutdown.

    Populates the module-level ``processor`` and ``model`` globals. The model
    is placed on CPU and switched to eval mode.

    Raises:
        RuntimeError: if ``HF_TOKEN`` is not set.
    """
    global model, processor

    print(f"[startup] Loading model: {MODEL_ID}")
    if not HF_TOKEN:
        raise RuntimeError("HF_TOKEN is missing. Add it in Hugging Face Space Secrets.")

    processor = AutoProcessor.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
    )
    model = Gemma3ForConditionalGeneration.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        device_map="cpu",
    ).eval()
    print("[startup] Model loaded successfully.")

    try:
        yield
    finally:
        # Ensure the shutdown message is emitted even if the app exits via an error.
        print("[shutdown] App is shutting down.")
# Application instance; `lifespan` loads the model before requests are served.
app = FastAPI(
    lifespan=lifespan,
    title="Gemma Intent Classifier API",
    version="1.0.0",
)
# =========================
# Routes
# =========================
# NOTE(review): the route decorator was missing from this copy of the file, so
# the handler was never registered; "/" is assumed from the function's role.
@app.get("/")
def root():
    """Return a static status payload confirming the API is running."""
    return {
        "status": "ok",
        "message": "Gemma Intent Classifier API is running.",
    }
# NOTE(review): the route decorator was missing from this copy of the file;
# "/health" is assumed from the function name.
@app.get("/health")
def health():
    """Return a health payload naming the configured model ID."""
    return {
        "status": "healthy",
        "model": MODEL_ID,
    }
# NOTE(review): the route decorator was missing from this copy of the file, so
# the endpoint was never registered; the "/classify" path is an assumption —
# confirm it matches existing clients.
@app.post("/classify", response_model=IntentResponse)
def classify_intent(payload: IntentRequest):
    """Classify ``payload.message`` into one of the allowed intents.

    Labels come from ``payload.intents`` when it is non-empty, otherwise from
    the configured ``ALLOWED_INTENTS``. Labels are stripped, and blank or
    duplicate labels are dropped (first occurrence kept).

    Raises:
        HTTPException(400): empty message, or no usable intent labels.
        HTTPException(503): the model has not been loaded yet.
        HTTPException(500): inference failed; the error is logged server-side
            and a generic message is returned to the client.
    """
    message = payload.message.strip() if payload.message else ""
    if not message:
        raise HTTPException(status_code=400, detail="message is required")

    requested = payload.intents if payload.intents else ALLOWED_INTENTS
    # Blank labels would appear in the prompt and match anything during
    # normalization, so clean the list before use.
    allowed_intents = list(dict.fromkeys(label.strip() for label in requested if label.strip()))
    if not allowed_intents:
        raise HTTPException(status_code=400, detail="No intents provided")

    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet")

    try:
        intent, raw_output = run_intent_classification(
            user_message=message,
            allowed_intents=allowed_intents,
            custom_system_prompt=payload.system_prompt,
        )
    except Exception as e:
        # Top-level boundary: keep the details in the server log instead of
        # echoing internal exception text back to the caller.
        print(f"[error] {repr(e)}")
        raise HTTPException(status_code=500, detail="Intent classification failed") from e

    print("========== REQUEST ==========")
    print(f"message: {payload.message}")
    print(f"allowed_intents: {allowed_intents}")
    print("========== RESPONSE =========")
    print(f"raw_output: {raw_output}")
    print(f"intent: {intent}")
    print("================================")

    return IntentResponse(
        intent=intent,
        raw_output=raw_output,
        model=MODEL_ID,
    )