from pathlib import Path
import os
import shutil
import logging

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from llama_cpp import Llama

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

MODEL_PATH = "model/Phi-3-mini-4k-instruct-q4.gguf"

# Set cache directory to /tmp, which has proper permissions
os.environ["HF_HOME"] = "/tmp/huggingface"

if not Path(MODEL_PATH).exists():
    from huggingface_hub import hf_hub_download

    # Ensure the model directory exists with proper permissions
    model_dir = Path("model")
    model_dir.mkdir(exist_ok=True)

    try:
        # Download to cache first, then copy into the model directory
        cached_model_path = hf_hub_download(
            repo_id="microsoft/Phi-3-mini-4k-instruct-gguf",
            filename="Phi-3-mini-4k-instruct-q4.gguf",
            cache_dir="/tmp/huggingface",
        )
        # Copy from cache to our model directory
        shutil.copy2(cached_model_path, MODEL_PATH)
        print(f"Model copied to {MODEL_PATH}")
    except PermissionError:
        # Fallback: use the cached model directly
        print(f"Permission denied copying to {MODEL_PATH}, using cached model directly")
        MODEL_PATH = cached_model_path

# Load the model with optimizations for the free tier
llm = Llama(
    model_path=MODEL_PATH,
    n_ctx=2048,         # Reduced context window for speed
    n_threads=2,        # Reduced threads to avoid resource competition
    n_batch=512,        # Smaller batch size
    use_mmap=True,      # Use memory mapping for efficiency
    use_mlock=False,    # Don't lock memory (may cause issues on free tier)
    low_vram=True,      # Optimize for low VRAM/RAM
    f16_kv=True,        # Use 16-bit key-value cache
    logits_all=False,   # Don't compute logits for all tokens
    vocab_only=False,
    verbose=False,      # Reduce logging overhead
)


class Req(BaseModel):
    prompt: str
    max_tokens: int | None = 64  # Much smaller default for speed


app = FastAPI(title="Phi-3 Chat API", description="A simple chat API using the Phi-3 model")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
def read_root():
    return {"message": "Phi-3 Chat API is running!", "model": "Phi-3-mini-4k-instruct"}


@app.get("/health")
def health_check():
    return {"status": "healthy", "model_loaded": True}


@app.post("/chat")
def chat(r: Req):
    try:
        logger.info(f"Received chat request: prompt='{r.prompt[:50]}...', max_tokens={r.max_tokens}")

        # Validate input
        if not r.prompt or len(r.prompt.strip()) == 0:
            raise HTTPException(status_code=400, detail="Prompt cannot be empty")

        # Strict limits for free-tier performance
        if r.max_tokens is None:
            r.max_tokens = 64
        if r.max_tokens > 128:  # Much stricter limit
            r.max_tokens = 128
        if r.max_tokens < 1:
            r.max_tokens = 1

        # Truncate the prompt if it's too long, to avoid timeouts
        if len(r.prompt) > 500:
            r.prompt = r.prompt[:500] + "..."

        logger.info(f"Processing with max_tokens={r.max_tokens}")

        # Generation parameters tuned for speed
        out = llm(
            prompt=r.prompt,
            max_tokens=r.max_tokens,
            stream=False,
            temperature=0.3,     # Lower temperature for faster, more focused responses
            top_p=0.7,           # More focused sampling
            top_k=20,            # Limit vocabulary for speed
            repeat_penalty=1.1,
            stop=["\n\n", "Human:", "Assistant:", "User:"],  # Stop early on common patterns
            echo=False,          # Don't echo the prompt back
        )

        response_text = out["choices"][0]["text"].strip()

        # Handle empty responses
        if not response_text:
            response_text = "I need more context to provide a helpful response."

        logger.info(f"Generated response length: {len(response_text)}")
        return {"answer": response_text}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in chat endpoint: {str(e)}")
        # Fallback response instead of an error
        return {"answer": "I'm experiencing high load. Please try a shorter message."}


@app.post("/fast-chat")
def fast_chat(r: Req):
    """Ultra-fast endpoint with very strict limits for the free tier."""
    try:
        logger.info(f"Fast chat request: {r.prompt[:30]}...")

        if not r.prompt or len(r.prompt.strip()) == 0:
            return {"answer": "Please provide a message."}

        # Ultra-strict limits for maximum speed
        max_tokens = min(r.max_tokens or 32, 32)  # Max 32 tokens
        prompt = r.prompt[:200]                   # Max 200 chars

        out = llm(
            prompt=prompt,
            max_tokens=max_tokens,
            stream=False,
            temperature=0.1,     # Very low for speed
            top_p=0.5,
            top_k=10,            # Very limited vocabulary
            repeat_penalty=1.0,
            stop=["\n", ".", "!", "?"],  # Stop at the first sentence
            echo=False,
        )

        response_text = out["choices"][0]["text"].strip()
        if not response_text:
            response_text = "OK"

        return {"answer": response_text}

    except Exception as e:
        logger.error(f"Fast chat error: {str(e)}")
        return {"answer": "Quick response unavailable."}