Spaces:
Runtime error
Runtime error
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
from typing import List, Optional, Dict, Any | |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, TextStreamer | |
import torch | |
# FastAPI application instance for this service.
app = FastAPI()
# Define the request schema | |
class PromptRequest(BaseModel):
    """Request body for a text-generation call.

    Attributes:
        prompt: The new user message to be answered.
        history: Optional list of earlier messages. Each entry is a dict;
            the prompt builder reads its ``"role"`` and ``"content"`` keys.
        parameters: Optional generation keyword arguments that override the
            server-side defaults.
    """

    prompt: str
    history: Optional[List[Dict[str, Any]]] = None
    parameters: Optional[Dict[str, Any]] = None
def load_model():
    """Load the tokenizer, model and generation pipeline into module globals.

    Reads everything from a fixed local snapshot directory and assigns the
    module-level names ``tokenizer``, ``model`` and ``pipe``. Nothing is
    returned; callers use the globals afterwards.
    """
    global model, tokenizer, pipe

    snapshot_dir = "model/models--meta-llama--Llama-3.2-3B-Instruct/snapshots/0cb88a4f764b7a12671c53f0838cd831a0843b95"

    tokenizer = AutoTokenizer.from_pretrained(snapshot_dir)
    model = AutoModelForCausalLM.from_pretrained(
        snapshot_dir,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    # The streamer is attached to the pipeline itself, so it is shared by
    # every generation call; skip_prompt=True keeps the prompt out of the
    # streamed output.
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        streamer=TextStreamer(tokenizer=tokenizer, skip_prompt=True),
    )
async def generate_response(request: PromptRequest):
    """Generate the assistant's reply for a prompt and optional chat history.

    The history entries and the new prompt are flattened into one plain-text
    transcript ending in ``Assistant:`` and passed to the module-level
    ``pipe``.

    Args:
        request: The prompt, optional history (dicts read through their
            ``"role"`` and ``"content"`` keys) and optional generation
            parameter overrides.

    Returns:
        ``{"response": <text>}``. By default the text is only the newly
        generated continuation; pass ``{"return_full_text": True}`` in
        ``parameters`` to get the prompt echoed back as well.

    Raises:
        HTTPException: Status 500 carrying the error message when generation
            (or reading its result) fails.
    """
    # Flatten prior turns into "role: content" lines. Entries missing a key
    # fall back to the "user" role / empty content.
    history_text = "".join(
        f"{message.get('role', 'user')}: {message.get('content', '')}\n"
        for message in request.history or []
    )

    # Combine history with the current prompt.
    full_prompt = f"{history_text}\nUser: {request.prompt}\nAssistant:"

    # Server defaults; anything the client sends in `parameters` wins.
    gen_params = {
        "max_new_tokens": 256,
        "temperature": 0.7,
        "top_p": 0.9,
        # Without this the pipeline returns prompt + completion, so the
        # "response" field would repeat the whole transcript.
        "return_full_text": False,
    }
    if request.parameters:
        # NOTE(review): client-supplied keys are forwarded to the pipeline
        # unfiltered; consider an allow-list if this endpoint is public.
        gen_params.update(request.parameters)

    try:
        # NOTE(review): this is a synchronous call inside an async function,
        # so the event loop is blocked while generation runs.
        result = pipe(full_prompt, **gen_params)
        return {"response": result[0]["generated_text"]}
    except Exception as e:
        # Chain the cause so the original traceback is kept in server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e