from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from torch import cuda
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
from dotenv import load_dotenv
import os
import time
import uvicorn
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Required for access to a gated model
load_dotenv()
hf_token = os.getenv("HF_TOKEN", None)
if hf_token is not None:
    login(token=hf_token)

# Configurable model identifier
model_name = os.getenv("HF_MODEL", "swiss-ai/Apertus-8B-Instruct-2509")

# Keep the model and tokenizer in module-level state between requests
model = None
tokenizer = None


class TextInput(BaseModel):
    text: str
    min_length: int = 3
    # Apertus by default supports a context length of up to 65,536 tokens.
    max_length: int = 65536


class ModelResponse(BaseModel):
    text: str
    confidence: float
    processing_time: float


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the transformer model on startup and release it on shutdown"""
    global model, tokenizer
    try:
        logger.info(f"Loading model: {model_name}")
        # Report which device is available
        device = "cuda" if cuda.is_available() else "cpu"
        # Load the tokenizer and the model
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",         # Automatically splits the model across CPU/GPU
            low_cpu_mem_usage=True,    # Avoids unnecessary CPU memory duplication
            offload_folder="offload",  # Temporary offload to disk
        )  # .to(device) is not needed when device_map="auto" handles placement
        logger.info(f"Model loaded successfully! ({device})")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise e
    # Release resources when the app is stopped
    yield
    del model
    del tokenizer
    cuda.empty_cache()


# Set up our app
app = FastAPI(
    title="Apertus API",
    description="REST API for serving Apertus models via Hugging Face transformers",
    version="0.1.0",
    docs_url="/",
    lifespan=lifespan,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/predict", response_model=ModelResponse)
async def predict(q: str):
    """Generate a model response for input text"""
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        start_time = time.time()
        input_data = TextInput(text=q)

        # Truncate the text if it is too long, reject it if it is too short
        if len(input_data.text) > input_data.max_length:
            logger.warning("Input text exceeds max_length, truncating")
        text = input_data.text[:input_data.max_length]
        if len(text) < input_data.min_length:
            logger.warning("Input text too short, aborting")
            raise HTTPException(status_code=400, detail="Input text too short")

        # Prepare the model input
        messages = [
            {"role": "user", "content": text}
        ]
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = tokenizer(
            [prompt],
            return_tensors="pt",
            add_special_tokens=False,
        ).to(model.device)

        # Generate the output; sampling parameters belong to generate(), not the chat template
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.8,
            top_p=0.9,
        )

        # Strip the prompt tokens and decode only the newly generated part
        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]
        result = tokenizer.decode(output_ids, skip_special_tokens=True)

        # Measure elapsed time
        processing_time = time.time() - start_time

        return ModelResponse(
            text=result,
            confidence=0.0,  # Plain text generation does not produce a confidence score
            processing_time=processing_time,
        )
    except HTTPException:
        # Preserve intentional HTTP errors (e.g. the 400 for too-short input)
        raise
    except Exception as e:
        logger.error(f"Evaluation error: {e}")
        raise HTTPException(status_code=500, detail="Evaluation failed")
@app.get("/health") async def health_check(): """Health check and basic configuration""" return { "status": "healthy", "model_loaded": model is not None, "gpu_available": cuda.is_available() } if __name__=='__main__': uvicorn.run('app:app', reload=True)