from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, Union
import torch
import logging
from pathlib import Path
from litgpt.api import LLM
import os
import uvicorn
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize FastAPI with simplified configuration
app = FastAPI(
    title="LLM Engine Service",
    docs_url="/docs",
    redoc_url="/redoc",
    openapi_url="/openapi.json"
)
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global variable to store the LLM instance
llm_instance: Optional[LLM] = None

class InitializeRequest(BaseModel):
"""
Configuration for model initialization including model path
"""
mode: str = "cpu"
precision: Optional[str] = None
quantize: Optional[str] = None
gpu_count: Union[str, int] = "auto"
model_path: str
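
# Example /initialize payload (illustrative values, not enforced defaults). It assumes the
# checkpoint has already been downloaded into ./checkpoints, e.g. via
# `litgpt download mistralai/Mistral-7B-Instruct-v0.3`:
#
#   {
#       "mode": "gpu",
#       "precision": "bf16-true",
#       "quantize": "bnb.nf4",
#       "gpu_count": 1,
#       "model_path": "mistralai/Mistral-7B-Instruct-v0.3"
#   }
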
class GenerateRequest(BaseModel):
    """Parameters for a single text-generation request."""
    prompt: str
    max_new_tokens: int = 50
    temperature: float = 1.0
    top_k: Optional[int] = None
    top_p: float = 1.0
    return_as_token_ids: bool = False
    stream: bool = False  # streaming is rejected by /generate below
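
# Example /generate payload (illustrative values):
#
#   {
#       "prompt": "What do Llamas eat?",
#       "max_new_tokens": 100,
#       "temperature": 0.7,
#       "top_k": 50,
#       "top_p": 0.95
#   }
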
@app.get("/")
async def root():
"""Root endpoint to verify service is running"""
return {
"status": "running",
"service": "LLM Engine",
"endpoints": {
"initialize": "/initialize",
"generate": "/generate",
"health": "/health"
}
}
@app.post("/initialize")
async def initialize_model(request: InitializeRequest):
"""
Initialize the LLM model with specified configuration.
"""
global llm_instance
try:
# Get the project root directory (where main.py is located)
project_root = Path(__file__).parent
checkpoints_dir = project_root / "checkpoints"
logger.info(f"Checkpoint dir is: {checkpoints_dir}")
# For LitGPT downloaded models, path includes organization
if "/" in request.model_path:
# e.g., "mistralai/Mistral-7B-Instruct-v0.3"
org, model_name = request.model_path.split("/")
model_path = str(checkpoints_dir / org / model_name)
else:
# Fallback for direct model paths
model_path = str(checkpoints_dir / request.model_path)
logger.info(f"Using model path: {model_path}")
# Load the model
llm_instance = LLM.load(
model=model_path,
distribute=None if request.precision or request.quantize else "auto"
)
# If manual distribution is needed
if request.precision or request.quantize:
llm_instance.distribute(
accelerator="cuda" if request.mode == "gpu" else "cpu",
devices=request.gpu_count,
precision=request.precision,
quantize=request.quantize
)
logger.info(
f"Model initialized successfully with config:\n"
f"Mode: {request.mode}\n"
f"Precision: {request.precision}\n"
f"Quantize: {request.quantize}\n"
f"GPU Count: {request.gpu_count}\n"
f"Model Path: {model_path}\n"
f"Current GPU Memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB allocated, "
f"{torch.cuda.memory_reserved()/1024**3:.2f}GB reserved"
)
return {"success": True, "message": "Model initialized successfully"}
except Exception as e:
logger.error(f"Error initializing model: {str(e)}")
# Print detailed memory statistics on failure
logger.error(f"GPU Memory Stats:\n"
f"Allocated: {torch.cuda.memory_allocated()/1024**3:.2f}GB\n"
f"Reserved: {torch.cuda.memory_reserved()/1024**3:.2f}GB\n"
f"Max Allocated: {torch.cuda.max_memory_allocated()/1024**3:.2f}GB")
raise HTTPException(status_code=500, detail=f"Error initializing model: {str(e)}")
@app.post("/generate")
async def generate(request: GenerateRequest):
"""
Generate text using the initialized model.
"""
global llm_instance
if llm_instance is None:
raise HTTPException(status_code=400, detail="Model not initialized. Call /initialize first.")
try:
if request.stream:
raise HTTPException(
status_code=400,
detail="Streaming is not currently supported through the API"
)
generated_text = llm_instance.generate(
prompt=request.prompt,
max_new_tokens=request.max_new_tokens,
temperature=request.temperature,
top_k=request.top_k,
top_p=request.top_p,
return_as_token_ids=request.return_as_token_ids,
stream=False # Force stream to False for now
)
response = {
"generated_text": generated_text if not request.return_as_token_ids else generated_text.tolist(),
"metadata": {
"prompt": request.prompt,
"max_new_tokens": request.max_new_tokens,
"temperature": request.temperature,
"top_k": request.top_k,
"top_p": request.top_p
}
}
return response
except Exception as e:
logger.error(f"Error generating text: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
@app.get("/health")
async def health_check():
"""
Check if the service is running and model is loaded.
"""
global llm_instance
status = {
"status": "healthy",
"model_loaded": llm_instance is not None,
}
if llm_instance is not None:
logger.info(f"llm_instance is: {llm_instance}")
status["model_info"] = {
"model_path": llm_instance.config.name,
"device": str(next(llm_instance.model.parameters()).device)
}
return status
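
# Example /health response once a model is loaded (illustrative; the exact model name
# and device depend on the initialized checkpoint):
#
#   {
#       "status": "healthy",
#       "model_loaded": true,
#       "model_info": {"model_path": "Mistral-7B-Instruct-v0.3", "device": "cuda:0"}
#   }
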
def main():
    # Load environment variables or configuration here
    host = os.getenv("LLM_ENGINE_HOST", "0.0.0.0")
    port = int(os.getenv("LLM_ENGINE_PORT", "7860"))  # Default to 7860 for Spaces

    # Log startup information
    logger.info(f"Starting LLM Engine service on {host}:{port}")
    logger.info("Available endpoints:")
    logger.info("  - /")
    logger.info("  - /health")
    logger.info("  - /initialize")
    logger.info("  - /generate")
    logger.info("  - /docs")
    logger.info("  - /redoc")
    logger.info("  - /openapi.json")

    # Start the server
    uvicorn.run(
        app,
        host=host,
        port=port,
        log_level="info"
    )

if __name__ == "__main__":
    main()
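
# --- Illustrative usage (comments only, not executed) ---
# A typical client session against this service, assuming the default host/port above and a
# locally downloaded Mistral checkpoint; the payload values are examples, not defaults:
#
#   curl http://localhost:7860/health
#   curl -X POST http://localhost:7860/initialize \
#        -H "Content-Type: application/json" \
#        -d '{"mode": "cpu", "model_path": "mistralai/Mistral-7B-Instruct-v0.3"}'
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "What do Llamas eat?", "max_new_tokens": 50}'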