from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os

# --- Configuration ---
MODEL_NAME = os.getenv("MODEL_NAME", "google/gemma-2b-it")  # Or "google/gemma-7b-it" if you have the resources
DEVICE = "cpu"  # Explicitly set to CPU
TORCH_DTYPE = torch.float32  # Use float32 on CPU for broader compatibility and stability.
# On some newer CPUs, bfloat16 may offer speedups if supported,
# but it can be less stable or require a specific setup.

# --- Model Loading ---
# This runs when the Docker container starts, or when the app is first imported.
# It can take a few minutes for larger models.
print(f"Loading model: {MODEL_NAME} on {DEVICE} with dtype {TORCH_DTYPE}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=TORCH_DTYPE,
        # low_cpu_mem_usage=True,  # Can help with very large models on CPU, but may slow down loading.
        # device_map="auto",       # 'auto' selects the CPU when no GPU is available;
        #                          # forcing CPU explicitly ensures no GPU is ever used.
    )
    model.to(DEVICE)  # Ensure the model is on CPU
    print(f"Model {MODEL_NAME} loaded successfully on {DEVICE}.")
except Exception as e:
    print(f"Error loading model: {e}")
    # If model loading fails, we can't serve requests.
    # Depending on the deployment, you might want to exit or handle this differently.
    raise RuntimeError(f"Failed to load model: {e}") from e

# --- FastAPI App ---
app = FastAPI(
    title="Gemma CPU Inference API",
    description="API to run inference on a Gemma model using CPU.",
    version="0.1.0",
)


class GenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 50
    temperature: float = 0.7
    do_sample: bool = True


class GenerationResponse(BaseModel):
    generated_text: str
    input_prompt: str


@app.post("/generate", response_model=GenerationResponse)
async def generate_text(request: GenerationRequest):
    """
    Generates text based on the input prompt using the loaded Gemma model.
    """
    if not model or not tokenizer:
        raise HTTPException(status_code=503, detail="Model not loaded or failed to load.")

    print(f"Received request: {request.prompt[:50]}...")  # Log a snippet of the prompt

    try:
        # Format the prompt for instruction-tuned models (like gemma-*-it).
        # This is a common format; adjust it if your model expects something different.
        chat = [
            {"role": "user", "content": request.prompt},
        ]
        formatted_prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

        input_ids = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE)

        print(f"Generating text with max_new_tokens={request.max_new_tokens}, temperature={request.temperature}...")
        with torch.no_grad():  # Important for inference: disables gradient tracking
            outputs = model.generate(
                **input_ids,
                max_new_tokens=request.max_new_tokens,
                temperature=request.temperature,
                do_sample=request.do_sample,
                # Add other generation parameters as needed: top_k, top_p, etc.
            )

        # Decode the generated text (only the new tokens).
        # The generated output includes the input prompt, so we slice it off.
        # For some models, the slice point might need adjustment:
        # decoded_text = tokenizer.decode(outputs[0, input_ids.input_ids.shape[1]:], skip_special_tokens=True)

        # A more robust way to get only the generated part, especially with chat templates:
        full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Remove the prompt part. This depends on how apply_chat_template works;
        # for many models, the rendered prompt is included verbatim in the output.
        # A simple check if the rendered prompt is directly prepended to the output:
        prompt_text = formatted_prompt.replace("<bos>", "").replace("<eos>", "")  # strip potential BOS/EOS tokens in the prompt
        if full_text.startswith(prompt_text):
            decoded_text = full_text[len(prompt_text):]
        else:
            # A fallback or more sophisticated stripping may be needed depending on the template.
            # For Gemma's instruction-tuned template, finding the start of the assistant's turn usually works.
            assistant_turn_start = "<start_of_turn>model\n"
            if assistant_turn_start in full_text:
                decoded_text = full_text.split(assistant_turn_start, 1)[-1]
            else:
                # If not found, the prompt may not be fully included in the output, the template may differ,
                # or skip_special_tokens=True above may already have stripped the turn markers.
                # As a simpler fallback, take everything after the input tokens.
                decoded_text = tokenizer.decode(outputs[0, input_ids.input_ids.shape[1]:], skip_special_tokens=True)

        print(f"Generated: {decoded_text[:100]}...")
        return GenerationResponse(generated_text=decoded_text.strip(), input_prompt=request.prompt)

    except Exception as e:
        print(f"Error during generation: {e}")
        raise HTTPException(status_code=500, detail=f"Error during generation: {str(e)}")


@app.get("/")
async def root():
    return {"message": "Gemma CPU Inference API is running. POST to /generate for inference."}


# To run locally (optional; the uvicorn command in the Docker CMD handles this in the container)
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
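
# --- Example usage (sketch) ---
# A minimal client sketch, assuming the API is running locally on port 8000
# (e.g. via the uvicorn call above or the Docker container) and that the
# third-party `requests` package is installed. Run it from a separate
# process or terminal, not from within this module:
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/generate",
#       json={"prompt": "Explain what a Docker container is.", "max_new_tokens": 100},
#       timeout=300,  # CPU generation can be slow, so allow a generous timeout
#   )
#   resp.raise_for_status()
#   print(resp.json()["generated_text"])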