| from __future__ import annotations |
|
|
| import os, json, time, uuid, asyncio, logging |
from typing import AsyncGenerator
| from contextlib import asynccontextmanager |
|
|
| from dotenv import load_dotenv |
| from fastapi import FastAPI, HTTPException, Request, Depends |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import StreamingResponse, JSONResponse |
| from pydantic import BaseModel |
| from gradio_client import Client |
|
|
| load_dotenv() |
|
|
# =============================================================================
# Configuration
# =============================================================================
| API_KEY = os.getenv("API_KEY", "") |
| HF_SPACE_URL = os.getenv("HF_SPACE_URL", "") |
| MODEL_ID = os.getenv("MODEL_ID", "") |
| DEFAULT_TEMP = float(os.getenv("DEFAULT_TEMPERATURE", "0.6")) |
| DEFAULT_TOP_P = float(os.getenv("DEFAULT_TOP_P", "0.95")) |
| DEFAULT_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "1024")) |
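# Example .env (sketch; every value below is a placeholder):
#
#   API_KEY=sk-local-dev-key
#   HF_SPACE_URL=user/falcon-h1r-space      # hypothetical Space id or full URL
#   MODEL_ID=falcon-h1r
#   DEFAULT_TEMPERATURE=0.6
#   DEFAULT_TOP_P=0.95
#   DEFAULT_MAX_TOKENS=1024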
|
|
| logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") |
| log = logging.getLogger(__name__) |
|
|
# =============================================================================
# Gradio client (lazy singleton)
# =============================================================================
| _client: Client | None = None |
|
|
async def get_client() -> Client:
    """Lazily create the shared Gradio client, running the blocking
    constructor in a worker thread. The lifespan hook warms it up at
    startup, so request handlers normally hit the cached instance."""
    global _client
    if _client is None:
        log.info("Connecting to %s", HF_SPACE_URL)
        _client = await asyncio.to_thread(Client, HF_SPACE_URL)
        log.info("Connected.")
    return _client
|
|
# =============================================================================
# Request schema
# =============================================================================
|
|
| class Message(BaseModel): |
| role: str |
| content: str | list[dict] = "" |
| name: str | None = None |
|
|
| class ChatCompletionRequest(BaseModel): |
| model: str = MODEL_ID |
| messages: list[Message] |
| temperature: float = DEFAULT_TEMP |
| top_p: float = DEFAULT_TOP_P |
| max_tokens: int = DEFAULT_TOKENS |
| stream: bool = False |
| frequency_penalty: float = 0 |
| presence_penalty: float = 0 |
| stop: str | list[str] | None = None |
| seed: int | None = None |
| user: str | None = None |
|
|
# =============================================================================
# Auth
# =============================================================================
|
|
async def verify_key(request: Request) -> None:
    """Require a matching Bearer token; auth is disabled when API_KEY is unset."""
    if not API_KEY:
        return
    auth = request.headers.get("Authorization", "")
    if not auth.startswith("Bearer ") or auth[7:] != API_KEY:
        raise HTTPException(status_code=401, detail="Invalid or missing API key")
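# Example authenticated request (sketch; assumes the server listens on
# localhost:8000 and API_KEY is set to "secret"):
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Authorization: Bearer secret" \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello"}]}'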
|
|
# =============================================================================
# Lifespan
# =============================================================================
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: warm the Gradio connection before serving traffic.
    log.info("Starting up - connecting to Gradio client...")
    await get_client()
    log.info("Startup complete.")
    yield
    # Shutdown: nothing to tear down explicitly.
    log.info("Shutting down.")
|
|
# =============================================================================
# App setup
# =============================================================================
|
|
| app = FastAPI( |
| title="Falcon H1R API", |
| version="3.1.0", |
| lifespan=lifespan, |
| ) |
|
|
# Wide-open CORS for convenience during development; tighten allow_origins
# before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
# =============================================================================
# Helpers
# =============================================================================
|
|
def _content_str(m: Message) -> str:
    """Normalize message content to a plain string. OpenAI-style multi-part
    content (e.g. [{"type": "text", "text": "hi"}]) is reduced to its text parts."""
    if isinstance(m.content, str):
        return m.content
    return "".join(p.get("text", "") for p in m.content if p.get("type") == "text")
|
|
def _build_prompt(messages: list[Message]) -> str:
    """Flatten messages into a single prompt string."""
    system, parts = [], []
    for m in messages:
        c = _content_str(m)
        if m.role == "system":
            system.append(c)
        elif m.role == "user":
            parts.append(c)
        elif m.role == "assistant":
            parts.append(f"[ASSISTANT]\n{c}")
    prefix = "[SYSTEM]\n" + "\n".join(system) + "\n[/SYSTEM]\n" if system else ""
    return prefix + "\n".join(parts)
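# Illustrative flattening (a three-turn conversation run through _build_prompt):
#
#   system: "Be terse."  user: "Hi"  assistant: "Hello"  user: "Bye?"
#
# produces:
#
#   [SYSTEM]
#   Be terse.
#   [/SYSTEM]
#   Hi
#   [ASSISTANT]
#   Hello
#   Bye?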
|
|
def _extract_text(result) -> str:
    """
    HTML chatbot does:
        const last = res.data[5].value.at(-1);
        const text = Array.isArray(last.content)
            ? last.content.filter(p => p.type === 'text').map(p => p.content.trim()).join('')
            : last.content;
    """
    try:
        # The Python gradio_client returns the outputs tuple directly, whereas
        # the JS client wraps it in a `data` field; accept either shape.
        outputs = result.data if hasattr(result, "data") else result
        # Output index 5 is the chatbot component.
        chatbot_data = outputs[5]
        # Its value is the conversation: a list of {role, content} messages.
        conversation = chatbot_data["value"]
        # Only the last (assistant) message matters.
        last = conversation[-1]
        content = last["content"]
        # Content may be a list of typed parts; keep only the text parts.
        if isinstance(content, list):
            return "".join(
                p["content"].strip()
                for p in content
                if p.get("type") == "text"
            )
        return str(content)
    except Exception as e:
        log.error("_extract_text failed: %s | raw result: %r", e, result)
        raise ValueError(f"Failed to extract text: {e}") from e
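# Illustrative shape of `result`, inferred from the JS snippet above (not
# verified against the live Space; only index 5 is shown):
#
#   ( ..., ..., ..., ..., ...,
#     {"value": [
#         {"role": "user", "content": "Hi"},
#         {"role": "assistant",
#          "content": [{"type": "text", "content": "Hello!"}]},
#     ]},
#   )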
|
|
async def _call_falcon(prompt: str, req: ChatCompletionRequest) -> str:
    """
    Replicates the HTML submit() flow:
      1. client.predict('/new_chat') to reset the Space's conversation state
      2. client.predict('/add_message', { input_value: msg, settings_form_value: PARAMS })
      3. Extract res.data[5].value.at(-1).content
    """
    client = await get_client()

    settings = {
        "model": req.model,
        "temperature": req.temperature,
        "max_new_tokens": req.max_tokens,
        "top_p": req.top_p,
    }

    # Reset the Space's chat state so prior requests don't leak into this one.
    await asyncio.to_thread(
        client.predict,
        api_name="/new_chat"
    )

    # Submit the flattened prompt with the generation settings.
    result = await asyncio.to_thread(
        client.predict,
        input_value=prompt,
        settings_form_value=settings,
        api_name="/add_message"
    )

    return _extract_text(result)
|
|
def _make_response(text: str, req: ChatCompletionRequest) -> dict:
    # Rough usage estimate: ~4 characters per token (no tokenizer available here).
    pt = sum(len(_content_str(m)) for m in req.messages) // 4
    ct = len(text) // 4
| return { |
| "id": f"chatcmpl-{uuid.uuid4().hex}", |
| "object": "chat.completion", |
| "created": int(time.time()), |
| "model": req.model, |
| "system_fingerprint": f"fp_{uuid.uuid4().hex[:8]}", |
| "choices": [{ |
| "index": 0, |
| "message": { |
| "role": "assistant", |
| "content": text, |
| "tool_calls": None, |
| "function_call": None, |
| }, |
| "finish_reason": "stop", |
| "logprobs": None, |
| }], |
| "usage": { |
| "prompt_tokens": pt, |
| "completion_tokens": ct, |
| "total_tokens": pt + ct, |
| }, |
| } |
|
|
async def _stream_sse(text: str, req: ChatCompletionRequest) -> AsyncGenerator[str, None]:
    """Simulate streaming by chunking the full response."""
    cid = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    # Emit the text in fixed-size deltas; per OpenAI convention, the role
    # field appears only in the first chunk.
    for i in range(0, len(text), 6):
        delta = {"content": text[i:i+6]}
        if i == 0:
            delta["role"] = "assistant"
        chunk = {
            "id": cid,
            "object": "chat.completion.chunk",
            "created": created,
            "model": req.model,
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": None,
            }],
        }
        yield f"data: {json.dumps(chunk)}\n\n"
        await asyncio.sleep(0.01)

    # Final chunk: empty delta, finish_reason, and the estimated usage block.
    pt = sum(len(_content_str(m)) for m in req.messages) // 4
    ct = len(text) // 4
    final = {
        "id": cid,
        "object": "chat.completion.chunk",
        "created": created,
        "model": req.model,
        "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
        "usage": {"prompt_tokens": pt, "completion_tokens": ct, "total_tokens": pt + ct},
    }
    yield f"data: {json.dumps(final)}\n\n"
    yield "data: [DONE]\n\n"
|
|
# =============================================================================
# Routes
# =============================================================================
|
|
| @app.get("/") |
| async def root(): |
| return { |
| "service": "Falcon H1R OpenAI-compatible API", |
| "version": "3.1.0", |
| "endpoints": { |
| "health": "/health", |
| "models": "/v1/models", |
| "chat": "/v1/chat/completions", |
| }, |
| } |
|
|
| @app.get("/health") |
| async def health(): |
| return {"status": "ok", "model": MODEL_ID, "space": HF_SPACE_URL} |
|
|
| @app.get("/v1/models") |
| async def list_models(_: None = Depends(verify_key)): |
| return {"object": "list", "data": [{ |
| "id": MODEL_ID, |
| "object": "model", |
| "created": 1710000000, |
| "owned_by": "tiiuae", |
| }]} |
|
|
| @app.post("/v1/chat/completions") |
| async def chat_completions(req: ChatCompletionRequest, _: None = Depends(verify_key)): |
| prompt = _build_prompt(req.messages) |
| log.info("Request | model=%s temp=%.2f tokens=%d stream=%s", |
| req.model, req.temperature, req.max_tokens, req.stream) |
| try: |
| text = await _call_falcon(prompt, req) |
| except Exception as exc: |
| log.exception("Falcon call failed") |
| raise HTTPException(status_code=502, detail=f"Upstream error: {exc}") from exc |
| if req.stream: |
| return StreamingResponse( |
| _stream_sse(text, req), |
| media_type="text/event-stream", |
| headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, |
| ) |
| |
| return JSONResponse(content=_make_response(text, req)) |
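

# Example client usage (sketch; assumes the official `openai` package and this
# server running on localhost:8000):
#
#   from openai import OpenAI
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="secret")
#   resp = client.chat.completions.create(
#       model="falcon-h1r",   # hypothetical model id; use your MODEL_ID
#       messages=[{"role": "user", "content": "Hello"}],
#   )
#   print(resp.choices[0].message.content)


if __name__ == "__main__":
    # Local dev entry point; assumes uvicorn is installed (pip install uvicorn).
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)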