Spaces:
Running
Running
| import os | |
| import io | |
| import base64 | |
| import ctypes | |
| import threading | |
| import json | |
| import time | |
| import uuid | |
| from flask import Flask, request, jsonify, Response | |
| from flask_cors import CORS | |
# --- Model Configuration ---
HF_REPO = "litert-community/gemma-4-E2B-it-litert-lm"
HF_FILE = "gemma-4-E2B-it.litertlm"
_SERVER_DIR = os.path.dirname(os.path.abspath(__file__))
_DEFAULT_PATH = os.path.join(_SERVER_DIR, "models", "gemma", HF_FILE)

# litert_lm links against libvulkan.so.1 even on CPU-only runs. When a stub
# copy sits next to this file, preload it with RTLD_GLOBAL (presumably so the
# dynamic linker can satisfy that dependency when litert_lm is imported later
# in load_model()).
_vk_stub = os.path.join(_SERVER_DIR, "libvulkan.so.1")
if os.path.exists(_vk_stub):
    try:
        ctypes.CDLL(_vk_stub, mode=ctypes.RTLD_GLOBAL)
    except OSError as e:
        # Still best-effort, but say so: otherwise a later litert_lm import
        # failure would have no visible cause in the logs.
        print(f"[WARN] Could not preload Vulkan stub {_vk_stub}: {e}", flush=True)

# Suppress verbose C++ logs from litert_lm; an explicitly set value wins.
os.environ.setdefault("GLOG_minloglevel", "3")

MODEL_PATH = os.environ.get("GEMMA_MODEL_PATH", _DEFAULT_PATH).strip()
MODEL_ID = "gemma-4-e2b"

# Shared state: written by load_model(), read by the request handlers.
model_status = "loading"
engine = None
_engine_ctx = None
# Despite the name this is a counting semaphore admitting up to two
# concurrent generations.
engine_lock = threading.BoundedSemaphore(value=2)

app = Flask(__name__)
CORS(app)
# ─── Model loading ─────────────────────────────────────────────────────────────
def load_model():
    """Create the LiteRT-LM engine and publish it through the module globals.

    The entry point runs this once on a daemon thread. The outcome is recorded
    in ``model_status``:

    * ``"no_model_path"``      - ``MODEL_PATH`` is empty
    * ``"no_litert_lm"``       - the ``litert_lm`` package cannot be imported
    * ``"model_file_missing"`` - ``MODEL_PATH`` does not exist on disk
    * ``"ready"``              - engine created; ``engine`` is usable
    * ``"error"``              - creating or entering the engine raised

    On success ``_engine_ctx`` holds the engine's context manager. It is entered
    here and deliberately never exited, so the engine lives for the whole
    process.
    """
    global engine, model_status, _engine_ctx
    if not MODEL_PATH:
        print("[INFO] GEMMA_MODEL_PATH not set — no model loaded", flush=True)
        model_status = "no_model_path"
        return
    try:
        import litert_lm as _lm
    except ImportError:
        print("[INFO] litert_lm not installed — no model loaded", flush=True)
        model_status = "no_litert_lm"
        return
    # Silencing native logs is cosmetic. If this litert_lm build lacks the API,
    # keep loading: an uncaught AttributeError would end this thread and leave
    # model_status stuck at "loading" forever.
    try:
        _lm.set_min_log_severity(_lm.LogSeverity.SILENT)
    except AttributeError as e:
        print(f"[WARN] Could not silence litert_lm logs: {e}", flush=True)
    if not os.path.exists(MODEL_PATH):
        print(f"[WARN] Model file not found: {MODEL_PATH}", flush=True)
        model_status = "model_file_missing"
        return
    try:
        _engine_ctx = _lm.Engine(
            MODEL_PATH,
            backend=_lm.interfaces.CPU(),
            vision_backend=_lm.interfaces.CPU(),
        )
        # Entered by hand: the engine must outlive this function, so there is
        # no ``with`` block to scope it to.
        engine = _engine_ctx.__enter__()
        model_status = "ready"
        print(f"[INFO] Model ready — {MODEL_PATH}", flush=True)
    except Exception as e:
        print(f"[ERROR] Failed to load model: {e}", flush=True)
        model_status = "error"
| # βββ OpenAI Request Parsing ββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def parse_openai_messages(messages: list) -> tuple[str, bytes | None]: | |
| """Parses OpenAI formatted messages into a flat text prompt and an optional image.""" | |
| prompt_text = "" | |
| image_bytes = None | |
| for msg in messages: | |
| role = msg.get("role", "user") | |
| content = msg.get("content", "") | |
| if isinstance(content, str): | |
| prompt_text += f"{role}: {content}\n" | |
| elif isinstance(content, list): | |
| prompt_text += f"{role}:\n" | |
| for part in content: | |
| if part.get("type") == "text": | |
| prompt_text += part.get("text", "") + "\n" | |
| elif part.get("type") == "image_url": | |
| url = part.get("image_url", {}).get("url", "") | |
| if url.startswith("data:image"): | |
| try: | |
| b64_data = url.split(",", 1)[1] | |
| image_bytes = base64.b64decode(b64_data) | |
| except Exception as e: | |
| print(f"[WARN] Failed to decode base64 image: {e}") | |
| prompt_text += "assistant: " | |
| return prompt_text.strip(), image_bytes | |
# ─── Inference Engine ──────────────────────────────────────────────────────────
def _run_real_model_generator(ask: str, image_bytes: bytes | None):
    """Stream text fragments produced by the loaded engine for one prompt.

    Because this is a generator, nothing happens until the first ``next()``.
    At that point it takes a slot from ``engine_lock`` and opens a fresh
    conversation. It sends either the plain prompt or, when *image_bytes* is
    non-empty, an image-then-text ``Contents`` message. It then yields every
    non-empty ``text`` part of every streamed chunk. The slot is released when
    the generator finishes, fails, or is closed.

    Raises:
        RuntimeError: no slot became free within 30 seconds.
    """
    import litert_lm

    # engine_lock admits at most two concurrent generations, to cap memory use.
    got_slot = engine_lock.acquire(timeout=30)
    if not got_slot:
        raise RuntimeError("Server busy. Try again shortly.")
    try:
        with engine.create_conversation() as conversation:
            payload = (
                litert_lm.Contents.of(
                    litert_lm.Content.ImageBytes(image_bytes),
                    litert_lm.Content.Text(ask),
                )
                if image_bytes
                else ask
            )
            for response_chunk in conversation.send_message_async(payload):
                for piece in response_chunk.get("content", []):
                    if piece.get("type") != "text":
                        continue
                    fragment = piece.get("text", "")
                    if fragment:
                        yield fragment
    finally:
        engine_lock.release()
def _run_mock_generator(ask: str, has_image: bool):
    """Yield a canned placeholder reply, one word at a time.

    Used in place of the real model while no engine is ready. *ask* is accepted
    only for signature parity with the real generator and is ignored. Each
    yielded word carries one trailing space and is followed by a 50 ms pause,
    imitating token streaming.
    """
    canned = f"[MOCK] Received prompt. Vision included: {has_image}. Connect litert_lm for real output."
    for token in (word + " " for word in canned.split()):
        yield token
        time.sleep(0.05)
# ─── Routes ────────────────────────────────────────────────────────────────────
# Without a route registration (none existed anywhere in this file) the
# endpoint was unreachable.
@app.route("/v1/models", methods=["GET"])
def list_models():
    """OpenAI-compatible model listing endpoint.

    Returns a ``list`` object holding the single model this server exposes
    (``MODEL_ID``). ``created`` is the time of the request, not of the model.
    """
    return jsonify({
        "object": "list",
        "data": [{
            "id": MODEL_ID,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "litert-community"
        }]
    })
@app.route("/v1/chat/completions", methods=["POST"])
def chat_completions():
    """OpenAI-compatible chat completions endpoint.

    Reads a JSON body with a ``messages`` array and optional ``stream`` and
    ``model`` fields. The messages are flattened with ``parse_openai_messages``.
    The prompt goes to the real model when an engine is loaded and
    ``model_status`` is ``"ready"``, otherwise to the mock generator.

    Returns:
        * 400 JSON error when ``messages`` is missing, empty or not a list.
        * ``stream`` truthy: a ``text/event-stream`` of
          ``chat.completion.chunk`` events. These are a role chunk, one chunk
          per generated text fragment, a final ``finish_reason: "stop"`` chunk,
          and then ``data: [DONE]``. A generation error is reported in-band as
          an ``{"error": ...}`` event before the final chunk.
        * otherwise: one ``chat.completion`` JSON object, or a 500 JSON error
          if generation raised.
    """
    data = request.get_json(silent=True)
    # Valid JSON that is not an object (e.g. a bare list) would make data.get
    # raise and surface as a 500; treat it like a missing body instead.
    if not isinstance(data, dict):
        data = {}
    messages = data.get("messages", [])
    stream = data.get("stream", False)
    if not isinstance(messages, list) or not messages:
        return jsonify({"error": {"message": "Missing 'messages' array", "type": "invalid_request_error"}}), 400

    ask, image_bytes = parse_openai_messages(messages)

    # Serve the mock reply until load_model() has published a ready engine.
    if engine is None or model_status != "ready":
        generator = _run_mock_generator(ask, bool(image_bytes))
    else:
        generator = _run_real_model_generator(ask, image_bytes)

    req_model = data.get("model", MODEL_ID)
    cmpl_id = f"chatcmpl-{uuid.uuid4().hex}"
    created_time = int(time.time())

    if stream:
        def _sse_chunk(delta: dict, finish_reason: str | None) -> str:
            """Serialise one ``chat.completion.chunk`` as an SSE ``data:`` event."""
            chunk = {
                "id": cmpl_id, "object": "chat.completion.chunk", "created": created_time, "model": req_model,
                "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}]
            }
            return f"data: {json.dumps(chunk)}\n\n"

        def stream_response():
            try:
                # 1. Initial chunk indicating role
                yield _sse_chunk({"role": "assistant"}, None)
                # 2. Stream tokens
                try:
                    for text_chunk in generator:
                        yield _sse_chunk({"content": text_chunk}, None)
                except Exception as e:
                    # Headers are already sent, so the error can only be reported in-band.
                    err_chunk = {"error": str(e)}
                    yield f"data: {json.dumps(err_chunk)}\n\n"
                # 3. Final chunk indicating stop
                yield _sse_chunk({}, "stop")
                yield "data: [DONE]\n\n"
            finally:
                # If the response is closed early (e.g. the client disconnects),
                # close the token generator too. Its finally block, which
                # releases engine_lock for the real model, then runs right away
                # instead of waiting for garbage collection.
                generator.close()

        return Response(stream_response(), mimetype="text/event-stream")

    try:
        full_text = "".join(generator)
    except Exception as e:
        return jsonify({"error": {"message": f"Model error: {e}", "type": "server_error"}}), 500
    response = {
        "id": cmpl_id,
        "object": "chat.completion",
        "created": created_time,
        "model": req_model,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": full_text
            },
            "finish_reason": "stop"
        }],
        "usage": {
            "prompt_tokens": 0,  # litert_lm token counting not implemented
            "completion_tokens": 0,
            "total_tokens": 0
        }
    }
    return jsonify(response)
# ─── Entry ─────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    listen_port = int(os.environ.get("PORT", 5173))
    # Load the model in the background so the HTTP server can start right away.
    # Until loading finishes, chat requests are answered by the mock generator.
    loader = threading.Thread(target=load_model, daemon=True)
    loader.start()
    print(f"[INFO] Gemma OpenAI-Compatible API listening on :{listen_port}", flush=True)
    app.run(host="0.0.0.0", port=listen_port, debug=False, threaded=True)