ndc8
Update Dockerfile and application entry point for GGUF backend; optimize memory usage in model parameters and requirements
358e717
#!/usr/bin/env python3
"""
Working Gemma 3n GGUF Backend Service
Minimal FastAPI backend using only llama-cpp-python for GGUF models
"""
import os
import logging
import time
from contextlib import asynccontextmanager
from typing import List, Dict, Any, Optional
import uuid
import sys
import subprocess
import threading
from pathlib import Path
import signal  # Use signal.SIGTERM for process termination

from fastapi import FastAPI, HTTPException, Query
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, field_validator

# Import llama-cpp-python for GGUF model support
try:
    from llama_cpp import Llama
    llama_cpp_available = True
except ImportError:
    llama_cpp_available = False

import uvicorn
import sqlite3
import json  # For persisting job metadata

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Pydantic models for OpenAI-compatible API
class ChatMessage(BaseModel):
    role: str = Field(..., description="The role of the message author")
    content: str = Field(..., description="The content of the message")

    @field_validator("role")
    @classmethod
    def validate_role(cls, v: str) -> str:
        if v not in ["system", "user", "assistant"]:
            raise ValueError("Role must be one of: system, user, assistant")
        return v
class ChatCompletionRequest(BaseModel):
    model: str = Field(default="gemma-3n-e4b-it", description="The model to use for completion")
    messages: List[ChatMessage] = Field(..., description="List of messages in the conversation")
    max_tokens: Optional[int] = Field(default=256, ge=1, le=1024, description="Maximum tokens to generate (reduced for memory efficiency)")
    temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: Optional[float] = Field(default=0.95, ge=0.0, le=1.0, description="Top-p sampling")
    top_k: Optional[int] = Field(default=64, ge=1, le=100, description="Top-k sampling")
    stream: Optional[bool] = Field(default=False, description="Whether to stream responses")

class ChatCompletionChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: str

class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionChoice]

class HealthResponse(BaseModel):
    status: str
    model: str
    version: str
    backend: str
# Global variables for model management
current_model = os.environ.get("AI_MODEL", "unsloth/gemma-3n-E4B-it-GGUF")
llm = None
def convert_messages_to_gemma_prompt(messages: List[ChatMessage]) -> str:
    """Convert OpenAI messages format to Gemma 3n chat format."""
    # Gemma 3n uses a specific format with <start_of_turn> and <end_of_turn>
    prompt_parts = ["<bos>"]
    for message in messages:
        role = message.role
        content = message.content
        if role == "system":
            prompt_parts.append(f"<start_of_turn>system\n{content}<end_of_turn>")
        elif role == "user":
            prompt_parts.append(f"<start_of_turn>user\n{content}<end_of_turn>")
        elif role == "assistant":
            prompt_parts.append(f"<start_of_turn>model\n{content}<end_of_turn>")
    # Add the start of the model response turn
    prompt_parts.append("<start_of_turn>model\n")
    return "\n".join(prompt_parts)
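# Example (illustrative only): for a single user message "Hello", the function above
# produces a prompt of the form:
#
#   <bos>
#   <start_of_turn>user
#   Hello<end_of_turn>
#   <start_of_turn>model
#
# The trailing "<start_of_turn>model\n" cues the model to generate the assistant turn.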
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager for startup and shutdown events"""
    global llm
    logger.info("🚀 Starting Gemma 3n GGUF Backend Service...")
    if os.environ.get("DEMO_MODE", "").strip() not in ("", "0", "false", "False"):
        logger.info("🧪 DEMO_MODE enabled: skipping model load")
        llm = None
        yield
        logger.info("🔄 Shutting down Gemma 3n Backend Service (demo mode)...")
        return
    if not llama_cpp_available:
        logger.error("❌ llama-cpp-python is not available. Please install it with: pip install llama-cpp-python")
        raise RuntimeError("llama-cpp-python not available")
    try:
        logger.info(f"📥 Loading Gemma 3n GGUF model from {current_model}...")
        # Configure model parameters optimized for HF Spaces memory constraints
        llm = Llama.from_pretrained(
            repo_id=current_model,
            filename="*Q4_0.gguf",  # Use Q4_0 instead of Q4_K_M for lower memory usage
            verbose=True,
            # Memory-optimized settings for HF Spaces
            n_ctx=2048,       # Reduced context length to save memory (was 4096)
            n_threads=2,      # Fewer threads for lower memory usage (was 4)
            n_gpu_layers=0,   # Force CPU-only to avoid GPU memory issues
            # Additional memory optimizations
            n_batch=512,      # Smaller batch size to reduce memory peaks
            use_mmap=True,    # Use memory mapping to reduce RAM usage
            use_mlock=False,  # Don't lock memory pages
            low_vram=True,    # Enable low VRAM mode for additional memory savings
            # Chat template for Gemma 3n format
            chat_format="gemma",  # Try the built-in gemma format first
        )
        logger.info("✅ Successfully loaded Gemma 3n GGUF model with memory optimizations")
    except Exception as e:
        logger.error(f"❌ Failed to initialize Gemma 3n model: {e}")
        logger.warning("⚠️ Please download the GGUF model file locally and update the path")
        logger.warning("⚠️ You can download it from: https://huggingface.co/unsloth/gemma-3n-E4B-it-GGUF")
        # For demo purposes, we'll continue without the model
        logger.info("🔄 Starting service in demo mode (responses will be mocked)")
    yield
    logger.info("🔄 Shutting down Gemma 3n Backend Service...")
    if llm:
        # Clean up model resources
        llm = None
# Initialize FastAPI app
app = FastAPI(
    title="Gemma 3n GGUF Backend Service",
    description="OpenAI-compatible chat completion API powered by Gemma-3n-E4B-it-GGUF",
    version="1.0.0",
    lifespan=lifespan
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure appropriately for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def ensure_model_ready():
    """Check if the model is loaded and ready."""
    # In demo mode, the service is allowed to run even without a model
    pass
def generate_response_gguf(messages: List[ChatMessage], max_tokens: int = 256, temperature: float = 1.0, top_p: float = 0.95, top_k: int = 64) -> str:
    """Generate a response using the GGUF model via llama-cpp-python (memory-optimized)."""
    if llm is None:
        # Demo mode response
        return "🤖 Demo mode: Gemma 3n model not loaded. This would be a real response from the Gemma 3n model. Please download the GGUF model from https://huggingface.co/unsloth/gemma-3n-E4B-it-GGUF"
    # Limit max_tokens for memory efficiency on HF Spaces
    max_tokens = min(max_tokens, 512)  # Cap at 512 tokens max
    try:
        # Use the chat completion method if available
        if hasattr(llm, 'create_chat_completion'):
            # Convert to dict format for llama-cpp-python
            messages_dict = [{"role": msg.role, "content": msg.content} for msg in messages]
            response = llm.create_chat_completion(
                messages=messages_dict,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                stop=["<end_of_turn>", "<eos>", "</s>"]  # Gemma 3n stop tokens
            )
            return response['choices'][0]['message']['content'].strip()
        else:
            # Fall back to direct prompt completion
            prompt = convert_messages_to_gemma_prompt(messages)
            response = llm(
                prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                stop=["<end_of_turn>", "<eos>", "</s>"],
                echo=False
            )
            return response['choices'][0]['text'].strip()
    except Exception as e:
        logger.error(f"GGUF generation failed: {e}")
        return "I apologize, but I'm having trouble generating a response right now. Please try again."
@app.get("/")
async def root() -> Dict[str, Any]:
    """Root endpoint with service information"""
    return {
        "message": "Gemma 3n GGUF Backend Service is running!",
        "model": current_model,
        "version": "1.0.0",
        "backend": "llama-cpp-python",
        "model_loaded": llm is not None,
        "endpoints": {
            "health": "/health",
            "chat_completions": "/v1/chat/completions"
        }
    }
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint"""
    return HealthResponse(
        status="healthy" if (llm is not None) else "demo_mode",
        model=current_model,
        version="1.0.0",
        backend="llama-cpp-python"
    )
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(
    request: ChatCompletionRequest
) -> ChatCompletionResponse:
    """Create a chat completion (OpenAI-compatible) using Gemma 3n GGUF."""
    try:
        ensure_model_ready()
        if not request.messages:
            raise HTTPException(status_code=400, detail="Messages cannot be empty")
        logger.info(f"Generating Gemma 3n response for {len(request.messages)} messages")
        response_text = generate_response_gguf(
            request.messages,
            request.max_tokens or 512,
            request.temperature or 1.0,
            request.top_p or 0.95,
            request.top_k or 64
        )
        response_text = response_text.strip() if response_text else "No response generated."
        return ChatCompletionResponse(
            id=f"chatcmpl-{int(time.time())}",
            created=int(time.time()),
            model=request.model,
            choices=[ChatCompletionChoice(
                index=0,
                message=ChatMessage(role="assistant", content=response_text),
                finish_reason="stop"
            )]
        )
    except HTTPException:
        # Re-raise intentional HTTP errors (e.g. the 400 above) instead of wrapping them as 500s
        raise
    except Exception as e:
        logger.error(f"Error in chat completion: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
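# Example request (illustrative; assumes the service is reachable on localhost:8000):
#
#   curl -s http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "gemma-3n-e4b-it", "messages": [{"role": "user", "content": "Hello!"}], "max_tokens": 64}'
#
# In demo mode (no model loaded) the returned "content" field contains the mocked placeholder response.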
# -----------------------------
# Training Job Management (Unsloth)
# -----------------------------
# Persistent job store: in-memory dict backed by SQLite
TRAIN_JOBS: Dict[str, Dict[str, Any]] = {}

# Initialize SQLite DB for job persistence
DB_PATH = Path(os.environ.get("JOB_DB_PATH", "./jobs.db"))
conn = sqlite3.connect(str(DB_PATH), check_same_thread=False)
cursor = conn.cursor()
cursor.execute(
    """
    CREATE TABLE IF NOT EXISTS jobs (
        job_id TEXT PRIMARY KEY,
        data TEXT NOT NULL
    )
    """
)
conn.commit()
def load_jobs() -> None:
    cursor.execute("SELECT job_id, data FROM jobs")
    for job_id, data in cursor.fetchall():
        TRAIN_JOBS[job_id] = json.loads(data)

def save_job(job_id: str) -> None:
    # Exclude the open log file handle: it is process-local and not JSON-serializable
    payload = {k: v for k, v in TRAIN_JOBS[job_id].items() if k != "log_file"}
    cursor.execute(
        "INSERT OR REPLACE INTO jobs (job_id, data) VALUES (?, ?)",
        (job_id, json.dumps(payload))
    )
    conn.commit()

# Load existing jobs on startup
load_jobs()
TRAIN_DIR = Path(os.environ.get("TRAIN_DIR", "./training_runs")).resolve()
TRAIN_DIR.mkdir(parents=True, exist_ok=True)

# Maximum concurrent training jobs
MAX_CONCURRENT_JOBS = int(os.environ.get("MAX_CONCURRENT_JOBS", "5"))
def _start_training_subprocess(job_id: str, args: Dict[str, Any]) -> subprocess.Popen[Any]:
    """Spawn a subprocess to run the Unsloth fine-tuning script."""
    logs_dir = TRAIN_DIR / job_id
    logs_dir.mkdir(parents=True, exist_ok=True)
    log_file = open(logs_dir / "train.log", "w", encoding="utf-8")
    # Store the log file handle so it can be closed later
    TRAIN_JOBS.setdefault(job_id, {})["log_file"] = log_file
    save_job(job_id)
    # Build an absolute script path to avoid module/package resolution issues
    script_path = (Path(__file__).parent / "training" / "train_gemma_unsloth.py").resolve()
    # Verify the training script exists
    if not script_path.exists():
        logger.error(f"Training script not found at {script_path}")
        raise HTTPException(status_code=500, detail=f"Training script not found at {script_path}")
    python_exec = sys.executable
    cmd = [
        python_exec,
        str(script_path),
        "--job-id", job_id,
        "--output-dir", str(logs_dir),
    ]

    # Optional user-specified args
    def _extend(k: str, v: Any):
        if v is None:
            return
        if isinstance(v, bool):
            cmd.extend([f"--{k}"] if v else [])
        else:
            cmd.extend([f"--{k}", str(v)])

    _extend("dataset", args.get("dataset"))
    _extend("text-field", args.get("text_field"))
    _extend("prompt-field", args.get("prompt_field"))
    _extend("response-field", args.get("response_field"))
    _extend("max-steps", args.get("max_steps"))
    _extend("epochs", args.get("epochs"))
    _extend("lr", args.get("lr"))
    _extend("batch-size", args.get("batch_size"))
    _extend("gradient-accumulation", args.get("gradient_accumulation"))
    _extend("lora-r", args.get("lora_r"))
    _extend("lora-alpha", args.get("lora_alpha"))
    _extend("cutoff-len", args.get("cutoff_len"))
    _extend("model-id", args.get("model_id"))
    _extend("use-bf16", args.get("use_bf16"))
    _extend("use-fp16", args.get("use_fp16"))
    _extend("seed", args.get("seed"))
    _extend("dry-run", args.get("dry_run"))

    logger.info(f"🧵 Starting training subprocess for job {job_id}: {' '.join(cmd)}")
    logger.info(f"🐍 Using interpreter: {python_exec}")
    proc = subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT, cwd=str(Path(__file__).parent))
    return proc
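# Example (illustrative): with default request values and a hypothetical JSONL dataset,
# the assembled command looks roughly like:
#
#   /usr/bin/python3 /app/training/train_gemma_unsloth.py --job-id 3f9c1a2b4d5e \
#     --output-dir ./training_runs/3f9c1a2b4d5e --dataset data/train.jsonl \
#     --epochs 1 --lr 0.0002 --batch-size 1 --gradient-accumulation 8 \
#     --lora-r 16 --lora-alpha 32 --cutoff-len 4096 --model-id unsloth/gemma-3n-E4B-it \
#     --use-bf16 --seed 42
#
# Interpreter and paths are placeholders; boolean options are emitted as bare flags only when true.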
def _watch_process(job_id: str, proc: subprocess.Popen[Any]):
    """Monitor a training process and update job state on exit."""
    return_code = proc.wait()
    status = "completed" if return_code == 0 else "failed"
    TRAIN_JOBS[job_id]["status"] = status
    TRAIN_JOBS[job_id]["return_code"] = return_code
    TRAIN_JOBS[job_id]["ended_at"] = int(time.time())
    # Persist the updated job status
    save_job(job_id)
    # Close the log file handle to prevent resource leaks
    log_file = TRAIN_JOBS[job_id].get("log_file")
    if log_file:
        try:
            log_file.close()
        except Exception as close_err:
            logger.warning(f"Failed to close log file for job {job_id}: {close_err}")
    logger.info(f"🏁 Training job {job_id} finished with status={status}, code={return_code}")
class StartTrainingRequest(BaseModel):
    dataset: str = Field(..., description="HF dataset name or path to a local JSONL/JSON file")
    model_id: Optional[str] = Field(default="unsloth/gemma-3n-E4B-it", description="Base model for training (HF Transformers format)")
    text_field: Optional[str] = Field(default=None, description="Single text field name (SFT)")
    prompt_field: Optional[str] = Field(default=None, description="Prompt/instruction field (chat data)")
    response_field: Optional[str] = Field(default=None, description="Response/output field (chat data)")
    max_steps: Optional[int] = Field(default=None)
    epochs: Optional[int] = Field(default=1)
    lr: Optional[float] = Field(default=2e-4)
    batch_size: Optional[int] = Field(default=1)
    gradient_accumulation: Optional[int] = Field(default=8)
    lora_r: Optional[int] = Field(default=16)
    lora_alpha: Optional[int] = Field(default=32)
    cutoff_len: Optional[int] = Field(default=4096)
    use_bf16: Optional[bool] = Field(default=True)
    use_fp16: Optional[bool] = Field(default=False)
    seed: Optional[int] = Field(default=42)
    dry_run: Optional[bool] = Field(default=False, description="Write DONE and exit without running (for CI/macOS)")
class StartTrainingResponse(BaseModel):
    job_id: str
    status: str
    output_dir: str

class TrainStatusResponse(BaseModel):
    job_id: str
    status: str
    created_at: int
    started_at: Optional[int] = None
    ended_at: Optional[int] = None
    output_dir: Optional[str] = None
    return_code: Optional[int] = None
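# Example request body for starting a fine-tuning job (illustrative; field names match
# StartTrainingRequest above, the dataset path is a placeholder):
#
#   {
#     "dataset": "data/train.jsonl",
#     "prompt_field": "instruction",
#     "response_field": "output",
#     "epochs": 1,
#     "batch_size": 1,
#     "gradient_accumulation": 8,
#     "lora_r": 16,
#     "lora_alpha": 32,
#     "dry_run": true
#   }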
# NOTE: the route paths on the training endpoints below are assumed for illustration;
# they are not specified elsewhere in this file.
@app.post("/v1/train/start", response_model=StartTrainingResponse)
def start_training(req: StartTrainingRequest):
    """Start a background Unsloth fine-tuning job. Returns a job_id to poll."""
    # Enforce the maximum number of concurrent training jobs
    running_jobs = sum(1 for job in TRAIN_JOBS.values() if job.get("status") == "running")
    if running_jobs >= MAX_CONCURRENT_JOBS:
        raise HTTPException(
            status_code=429,
            detail=f"Maximum concurrent training jobs reached ({MAX_CONCURRENT_JOBS}). Try again later."
        )
    job_id = uuid.uuid4().hex[:12]
    now = int(time.time())
    output_dir = str((TRAIN_DIR / job_id).resolve())
    TRAIN_JOBS[job_id] = {
        "status": "starting",
        "created_at": now,
        "started_at": now,
        "args": req.model_dump(),
        "output_dir": output_dir,
    }
    save_job(job_id)
    try:
        proc = _start_training_subprocess(job_id, req.model_dump())
        TRAIN_JOBS[job_id]["status"] = "running"
        TRAIN_JOBS[job_id]["pid"] = proc.pid
        save_job(job_id)
        watcher = threading.Thread(target=_watch_process, args=(job_id, proc), daemon=True)
        watcher.start()
        return StartTrainingResponse(job_id=job_id, status="running", output_dir=output_dir)
    except Exception as e:
        logger.exception("Failed to start training job")
        TRAIN_JOBS[job_id]["status"] = "failed_to_start"
        save_job(job_id)
        raise HTTPException(status_code=500, detail=f"Failed to start training: {e}")
@app.get("/v1/train/status/{job_id}", response_model=TrainStatusResponse)  # Route path assumed
def train_status(job_id: str):
    job = TRAIN_JOBS.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    return TrainStatusResponse(
        job_id=job_id,
        status=job.get("status", "unknown"),
        created_at=job.get("created_at", 0),
        started_at=job.get("started_at"),
        ended_at=job.get("ended_at"),
        output_dir=job.get("output_dir"),
        return_code=job.get("return_code"),
    )
@app.get("/v1/train/logs/{job_id}")  # Route path assumed
def train_logs(
    job_id: str,
    tail: int = Query(200, ge=0, le=1000, description="Number of lines to tail, between 0 and 1000"),
):
    job = TRAIN_JOBS.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    log_path = Path(job["output_dir"]) / "train.log"
    if not log_path.exists():
        return {"job_id": job_id, "logs": "(no logs yet)"}
    try:
        with open(log_path, "r", encoding="utf-8", errors="ignore") as f:
            # tail == 0 means "no lines"; note that [-0:] would return the whole file
            lines = f.readlines()[-tail:] if tail > 0 else []
        return {"job_id": job_id, "logs": "".join(lines)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to read logs: {e}")
@app.post("/v1/train/stop/{job_id}")  # Route path assumed
def train_stop(job_id: str):
    job = TRAIN_JOBS.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    pid = job.get("pid")
    if not pid:
        raise HTTPException(status_code=400, detail="Job does not have an active PID")
    try:
        os.kill(pid, signal.SIGTERM)
    except ProcessLookupError:
        logger.warning(
            f"Process {pid} for job {job_id} not found; it may have exited already"
        )
        job["status"] = "stopping_failed"
        save_job(job_id)
        return {"job_id": job_id, "status": job["status"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to stop job: {e}")
    else:
        job["status"] = "stopping"
        save_job(job_id)
        return {"job_id": job_id, "status": "stopping"}
# Main entry point
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
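# Alternatively, the app can be served with the uvicorn CLI (the module name depends on the
# filename this script is saved under; "backend_service" below is only a placeholder):
#
#   uvicorn backend_service:app --host 0.0.0.0 --port 8000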