|
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os

# Configuration: the model can be overridden via the MODEL_NAME environment variable.
# Inference runs on CPU with float32, the safe default dtype for CPU execution.
MODEL_NAME = os.getenv("MODEL_NAME", "google/gemma-2b-it")
DEVICE = "cpu"
TORCH_DTYPE = torch.float32
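
# Optional CPU tuning (an addition, not in the original script): cap PyTorch's
# intra-op thread pool so the process does not oversubscribe cores in a
# containerized deployment. TORCH_NUM_THREADS is a hypothetical variable name
# used here for illustration.
if os.getenv("TORCH_NUM_THREADS"):
    torch.set_num_threads(int(os.environ["TORCH_NUM_THREADS"]))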

print(f"Loading model: {MODEL_NAME} on {DEVICE} with dtype {TORCH_DTYPE}...")
try:
    # Load the tokenizer and weights once at startup so every request reuses them.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=TORCH_DTYPE,
    )
    model.to(DEVICE)
    print(f"Model {MODEL_NAME} loaded successfully on {DEVICE}.")
except Exception as e:
    print(f"Error loading model: {e}")
    # Fail fast: the API cannot serve requests without a model.
    raise RuntimeError(f"Failed to load model: {e}") from e

app = FastAPI(
    title="Gemma CPU Inference API",
    description="API to run inference on a Gemma model using CPU.",
    version="0.1.0",
)

class GenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 50
    temperature: float = 0.7
    do_sample: bool = True


class GenerationResponse(BaseModel):
    generated_text: str
    input_prompt: str
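
# A possible tightening (not part of the original models): bound the sampling
# parameters with pydantic's Field (from pydantic import Field), e.g.
#   max_new_tokens: int = Field(default=50, ge=1, le=512)
#   temperature: float = Field(default=0.7, gt=0.0, le=2.0)
# so obviously invalid requests are rejected before they reach model.generate().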


@app.post("/generate", response_model=GenerationResponse)
async def generate_text(request: GenerationRequest):
    """
    Generates text based on the input prompt using the loaded Gemma model.
    """
    # Defensive check; startup already fails fast, but return a clear 503 if the
    # model is somehow unavailable.
    if not model or not tokenizer:
        raise HTTPException(status_code=503, detail="Model not loaded or failed to load.")

    print(f"Received request: {request.prompt[:50]}...")

    try:
        # Wrap the prompt in Gemma's chat format and add the generation prompt so
        # the model answers as the assistant turn.
        chat = [
            {"role": "user", "content": request.prompt},
        ]
        formatted_prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE)

        print(f"Generating text with max_new_tokens={request.max_new_tokens}, temperature={request.temperature}...")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=request.max_new_tokens,
                temperature=request.temperature,
                do_sample=request.do_sample,
            )

        # Decode the full sequence (prompt + completion), then strip the prompt so
        # only the newly generated text is returned.
        full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        prompt_text = formatted_prompt.replace("<bos>", "").replace("<eos>", "")
        if full_text.startswith(prompt_text):
            decoded_text = full_text[len(prompt_text):]
        else:
            # Fallback 1: split on Gemma's assistant-turn marker.
            assistant_turn_start = "<start_of_turn>model\n"
            if assistant_turn_start in full_text:
                decoded_text = full_text.split(assistant_turn_start, 1)[-1]
            else:
                # Fallback 2: decode only the tokens produced after the prompt.
                decoded_text = tokenizer.decode(outputs[0, inputs.input_ids.shape[1]:], skip_special_tokens=True)

        print(f"Generated: {decoded_text[:100]}...")

        return GenerationResponse(generated_text=decoded_text.strip(), input_prompt=request.prompt)

    except Exception as e:
        print(f"Error during generation: {e}")
        raise HTTPException(status_code=500, detail=f"Error during generation: {str(e)}")


@app.get("/")
async def root():
    return {"message": "Gemma CPU Inference API is running. POST to /generate for inference."}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
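
# Example client call (a sketch, assuming the server is reachable on localhost:8000
# and the third-party `requests` package is installed; adjust host/port as needed):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/generate",
#       json={"prompt": "Explain CPUs vs GPUs in one sentence.", "max_new_tokens": 64},
#   )
#   print(resp.json()["generated_text"])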