import os
import io
import base64
import ctypes
import threading
import json
import time
import uuid

from flask import Flask, request, jsonify, Response
from flask_cors import CORS

# --- Model Configuration ---
HF_REPO = "litert-community/gemma-4-E2B-it-litert-lm"
HF_FILE = "gemma-4-E2B-it.litertlm"
_SERVER_DIR = os.path.dirname(os.path.abspath(__file__))
_DEFAULT_PATH = os.path.join(_SERVER_DIR, "models", "gemma", HF_FILE)

# litert_lm links against libvulkan.so.1 even on CPU-only runs.
# If a stub copy ships next to this file, preload it with RTLD_GLOBAL so its
# symbols are visible process-wide (presumably so litert_lm can load on hosts
# without a system Vulkan loader — confirm against deployment notes).
_vk_stub = os.path.join(_SERVER_DIR, "libvulkan.so.1")
if os.path.exists(_vk_stub):
    try:
        ctypes.CDLL(_vk_stub, mode=ctypes.RTLD_GLOBAL)
    except OSError as e:
        # Best effort only: a system-installed libvulkan may still work.
        print(f"[WARN] Could not preload Vulkan stub {_vk_stub}: {e}", flush=True)

# Suppress verbose C++ logs from litert_lm
os.environ.setdefault("GLOG_minloglevel", "3")

# Falls back to the bundled default path; only empty if the env var is set blank.
MODEL_PATH = os.environ.get("GEMMA_MODEL_PATH", _DEFAULT_PATH).strip()
MODEL_ID = "gemma-4-e2b"

# One of: "loading", "ready", "no_model_path", "no_litert_lm",
# "model_file_missing", "error". Written by load_model().
model_status = "loading"
# The entered litert_lm engine; stays None until loading succeeds.
engine = None
# The Engine context manager object, kept so it is not garbage-collected while
# `engine` is in use. It is never exited: the engine lives for the process.
_engine_ctx = None
# Caps concurrent generations (a counting semaphore despite the name).
engine_lock = threading.BoundedSemaphore(value=2)

app = Flask(__name__)
CORS(app)


# ─── Model loading ─────────────────────────────────────────────────────────────
def load_model():
    """Load the Gemma LiteRT-LM engine into the module-level ``engine``.

    Meant to run in a background thread. Expected failure modes (no path,
    litert_lm missing, model file missing, engine construction failure) are
    never raised; the outcome is recorded in ``model_status`` instead so that
    request handlers can fall back to mock output.
    """
    global engine, model_status, _engine_ctx

    if not MODEL_PATH:
        print("[INFO] GEMMA_MODEL_PATH is empty — no model loaded", flush=True)
        model_status = "no_model_path"
        return

    try:
        import litert_lm as _lm
    except ImportError:
        print("[INFO] litert_lm not installed — no model loaded", flush=True)
        model_status = "no_litert_lm"
        return

    # Kept separate from the import guard: if this litert_lm build lacks the
    # log-severity API, an AttributeError here used to kill the loader thread
    # and leave model_status stuck at "loading".
    try:
        _lm.set_min_log_severity(_lm.LogSeverity.SILENT)
    except AttributeError as e:
        print(f"[WARN] Could not silence litert_lm logging: {e}", flush=True)

    if not os.path.exists(MODEL_PATH):
        print(f"[WARN] Model file not found: {MODEL_PATH}", flush=True)
        model_status = "model_file_missing"
        return

    try:
        _engine_ctx = _lm.Engine(
            MODEL_PATH,
            backend=_lm.interfaces.CPU(),
            vision_backend=_lm.interfaces.CPU(),
        )
        # Publish `engine` before flipping the status: handlers check both.
        engine = _engine_ctx.__enter__()
        model_status = "ready"
        print(f"[INFO] Model ready → {MODEL_PATH}", flush=True)
    except Exception as e:
        print(f"[ERROR] Failed to load model: {e}", flush=True)
        model_status = "error"


# ─── OpenAI Request Parsing ────────────────────────────────────────────────────
def _decode_data_url_image(image_url) -> bytes | None:
    """Return the decoded bytes of a base64 ``data:image`` URL, or ``None``.

    Accepts the OpenAI object form ``{"url": "..."}`` or a bare URL string.
    Anything that is not a ``data:image`` URL (e.g. an ``https://`` link) is
    ignored, since this server does not fetch remote images. Malformed
    payloads are logged and ignored.
    """
    url = image_url.get("url", "") if isinstance(image_url, dict) else image_url
    if not isinstance(url, str) or not url.startswith("data:image"):
        return None
    try:
        # Everything after the first comma is the base64 payload.
        return base64.b64decode(url.split(",", 1)[1])
    except (IndexError, ValueError) as e:  # binascii.Error is a ValueError
        print(f"[WARN] Failed to decode base64 image: {e}", flush=True)
        return None


def parse_openai_messages(messages: list) -> tuple[str, bytes | None]:
    """Flatten OpenAI-style chat messages into one text prompt plus an image.

    Each message becomes a ``"<role>: <text>"`` section, and the prompt ends
    with an ``"assistant:"`` cue. Content may be a string or a list of parts:
    ``"text"`` parts are appended one per line, and ``"image_url"`` parts with
    a base64 ``data:image`` URL are decoded. Only the last successfully
    decoded image is returned; earlier ones are dropped. Messages whose
    content is neither a string nor a list (e.g. ``None``) contribute nothing.

    Args:
        messages: The ``messages`` array from a chat-completions request.

    Returns:
        ``(prompt_text, image_bytes)``; ``image_bytes`` is ``None`` when no
        image was decoded.

    Raises:
        ValueError: if a message or a content part is not a JSON object.
    """
    lines: list[str] = []
    image_bytes = None
    for msg in messages:
        if not isinstance(msg, dict):
            raise ValueError("Each entry in 'messages' must be an object")
        role = msg.get("role", "user")
        content = msg.get("content", "")
        if isinstance(content, str):
            lines.append(f"{role}: {content}\n")
        elif isinstance(content, list):
            lines.append(f"{role}:\n")
            for part in content:
                if not isinstance(part, dict):
                    raise ValueError("Each content part must be an object")
                part_type = part.get("type")
                if part_type == "text":
                    text = part.get("text", "")
                    lines.append(f"{text if isinstance(text, str) else ''}\n")
                elif part_type == "image_url":
                    decoded = _decode_data_url_image(part.get("image_url"))
                    if decoded is not None:
                        image_bytes = decoded
    lines.append("assistant: ")
    # strip() also drops the trailing space after the "assistant:" cue.
    return "".join(lines).strip(), image_bytes


# ─── Inference Engine ──────────────────────────────────────────────────────────
class ServerBusyError(RuntimeError):
    """Raised when no generation slot frees up within the wait timeout."""


# Seconds a request waits for a free generation slot before giving up.
_GENERATION_SLOT_TIMEOUT_S = 30


def _run_real_model_generator(ask: str, image_bytes: bytes | None):
    """Yield text chunks from the loaded litert_lm engine as they arrive.

    The generation slot (``engine_lock``) is acquired lazily on the first
    ``next()`` and released in ``finally``, so the caller must exhaust or
    ``close()`` this generator to free the slot promptly.

    Raises:
        ServerBusyError: if no slot frees up within the timeout.
    """
    import litert_lm

    # engine_lock caps concurrent generations to prevent RAM crashes.
    if not engine_lock.acquire(timeout=_GENERATION_SLOT_TIMEOUT_S):
        raise ServerBusyError("Server busy. Try again shortly.")
    try:
        with engine.create_conversation() as conv:
            if image_bytes:
                msg = litert_lm.Contents.of(
                    litert_lm.Content.ImageBytes(image_bytes),
                    litert_lm.Content.Text(ask),
                )
            else:
                msg = ask
            # NOTE(review): assumes each streamed chunk is a dict shaped like
            # {"content": [{"type": "text", "text": ...}, ...]} — confirm
            # against the litert_lm API.
            for chunk in conv.send_message_async(msg):
                for part in chunk.get("content", []):
                    if part.get("type") == "text":
                        text = part.get("text", "")
                        if text:
                            yield text
    finally:
        engine_lock.release()


def _run_mock_generator(ask: str, has_image: bool):
    """Fallback generator used when the model is missing or still loading."""
    msg = f"[MOCK] Received prompt. Vision included: {has_image}. Connect litert_lm for real output."
    for word in msg.split():
        yield word + " "
        time.sleep(0.05)


def _generation_error(exc: Exception) -> tuple[int, dict]:
    """Map a generation failure to ``(http_status, OpenAI-style error body)``."""
    if isinstance(exc, ServerBusyError):
        return 503, {"error": {"message": str(exc), "type": "server_error"}}
    return 500, {"error": {"message": f"Model error: {exc}", "type": "server_error"}}


# ─── Routes ────────────────────────────────────────────────────────────────────
# Fixed at import so /v1/models reports a stable "created" timestamp.
_MODEL_CREATED = int(time.time())


@app.route("/v1/models", methods=["GET"])
def list_models():
    """OpenAI models endpoint."""
    return jsonify({
        "object": "list",
        "data": [{
            "id": MODEL_ID,
            "object": "model",
            "created": _MODEL_CREATED,
            "owned_by": "litert-community",
        }],
    })


@app.route("/v1/chat/completions", methods=["POST"])
def chat_completions():
    """OpenAI-compatible chat completions endpoint (streaming and non-streaming).

    Falls back to a mock generator whenever the model is not ready. Returns
    400 for malformed requests, 503 when no generation slot is available
    (non-streaming only), and 500 for other model errors.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    messages = data.get("messages", [])
    stream = data.get("stream", False)

    if not isinstance(messages, list) or not messages:
        return jsonify({"error": {"message": "Missing 'messages' array", "type": "invalid_request_error"}}), 400

    try:
        ask, image_bytes = parse_openai_messages(messages)
    except ValueError as e:
        return jsonify({"error": {"message": str(e), "type": "invalid_request_error"}}), 400

    # Determine which generator to use
    if engine is None or model_status != "ready":
        generator = _run_mock_generator(ask, bool(image_bytes))
    else:
        generator = _run_real_model_generator(ask, image_bytes)

    req_model = data.get("model", MODEL_ID)
    cmpl_id = f"chatcmpl-{uuid.uuid4().hex}"
    created_time = int(time.time())

    if stream:
        def make_chunk(delta: dict, finish_reason=None) -> dict:
            """Build one ``chat.completion.chunk`` event payload."""
            return {
                "id": cmpl_id,
                "object": "chat.completion.chunk",
                "created": created_time,
                "model": req_model,
                "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
            }

        def sse(payload) -> str:
            """Serialize a payload as one Server-Sent Events data frame."""
            return f"data: {json.dumps(payload)}\n\n"

        def stream_response():
            try:
                # 1. Initial chunk indicating role
                yield sse(make_chunk({"role": "assistant"}))
                # 2. Stream tokens. Headers are already sent, so failures
                #    (including "server busy") are reported in-band.
                try:
                    for text_chunk in generator:
                        yield sse(make_chunk({"content": text_chunk}))
                except Exception as e:
                    print(f"[ERROR] Streaming generation failed: {e}", flush=True)
                    yield sse(_generation_error(e)[1])
                # 3. Final chunk indicating stop
                yield sse(make_chunk({}, "stop"))
                yield "data: [DONE]\n\n"
            finally:
                # If the client disconnects mid-stream, this generator is
                # closed; close the model generator too so its finally block
                # releases engine_lock now rather than at garbage collection.
                generator.close()

        return Response(
            stream_response(),
            mimetype="text/event-stream",
            headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
        )

    try:
        full_text = "".join(generator)
    except Exception as e:
        print(f"[ERROR] Generation failed: {e}", flush=True)
        status, body = _generation_error(e)
        return jsonify(body), status

    return jsonify({
        "id": cmpl_id,
        "object": "chat.completion",
        "created": created_time,
        "model": req_model,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": full_text},
            "finish_reason": "stop",
        }],
        "usage": {
            "prompt_tokens": 0,  # litert_lm token counting not implemented
            "completion_tokens": 0,
            "total_tokens": 0,
        },
    })


# ─── Entry ─────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 5173))
    # Load the model in the background so the server answers (with mock
    # output) while the engine initializes.
    threading.Thread(target=load_model, daemon=True).start()
    print(f"[INFO] Gemma OpenAI-Compatible API listening on :{port}", flush=True)
    app.run(
        host="0.0.0.0",
        port=port,
        debug=False,
        threaded=True,
    )