from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, TextStreamer
import torch
import os

app = FastAPI()


# Define the request schema
class PromptRequest(BaseModel):
    prompt: str
    history: Optional[List[Dict[str, Any]]] = None
    parameters: Optional[Dict[str, Any]] = None


@app.on_event("startup")
def load_model():
    # Load the tokenizer, model, and generation pipeline once at startup
    global model, tokenizer, pipe
    os.environ["TRANSFORMERS_CACHE"] = "./cache"
    model_path = "model/models--meta-llama--Llama-3.2-3B-Instruct/snapshots/0cb88a4f764b7a12671c53f0838cd831a0843b95"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    streamer = TextStreamer(tokenizer=tokenizer, skip_prompt=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16, cache_dir="./cache"
    )
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, streamer=streamer)


@app.post("/generate/")
async def generate_response(request: PromptRequest):
    # Format the prompt with message history
    history_text = ""
    if request.history:
        for message in request.history:
            role = message.get("role", "user")
            content = message.get("content", "")
            history_text += f"{role}: {content}\n"

    # Combine history with the current prompt
    full_prompt = f"{history_text}\nUser: {request.prompt}\nAssistant:"

    # Set default parameters and update with any provided
    gen_params = {
        "max_new_tokens": 256,
        "temperature": 0.7,
        "top_p": 0.9,
    }
    if request.parameters:
        gen_params.update(request.parameters)

    # Generate the response
    try:
        result = pipe(full_prompt, **gen_params)
        return {"response": result[0]["generated_text"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
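

# --- Usage sketch (illustrative, not part of the server code above) ---
# Assuming the app is saved as main.py and served with, e.g.:
#     uvicorn main:app --host 0.0.0.0 --port 8000
# the /generate/ endpoint can be exercised with the `requests` library.
# The module name, host, port, and example payload values are assumptions;
# the field names (prompt, history, parameters) and the "response" key
# come from the schema and return value defined above.
#
#   import requests
#
#   payload = {
#       "prompt": "What is the capital of France?",
#       "history": [
#           {"role": "user", "content": "Hi"},
#           {"role": "assistant", "content": "Hello! How can I help?"},
#       ],
#       "parameters": {"max_new_tokens": 64, "temperature": 0.5},
#   }
#   resp = requests.post("http://localhost:8000/generate/", json=payload)
#   print(resp.json()["response"])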