from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, TextStreamer
import torch

app = FastAPI()


# Define the request schema
class PromptRequest(BaseModel):
    prompt: str
    history: Optional[List[Dict[str, Any]]] = None
    parameters: Optional[Dict[str, Any]] = None
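# Example request body for /generate/ (illustrative values only, not part of the original file):
# {
#     "prompt": "What is the capital of France?",
#     "history": [
#         {"role": "user", "content": "Hi"},
#         {"role": "assistant", "content": "Hello! How can I help?"}
#     ],
#     "parameters": {"max_new_tokens": 128, "temperature": 0.2}
# }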
@app.on_event("startup")
def load_model():
    global model, tokenizer, pipe
    model_path = "model/models--meta-llama--Llama-3.2-3B-Instruct/snapshots/0cb88a4f764b7a12671c53f0838cd831a0843b95"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Stream generated tokens to the server's stdout as they are produced; the prompt itself is skipped
    streamer = TextStreamer(tokenizer=tokenizer, skip_prompt=True)
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto")
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, streamer=streamer)
@app.post("/generate/")
async def generate_response(request: PromptRequest):
    # Format the prompt with the message history
    history_text = ""
    if request.history:
        for message in request.history:
            role = message.get("role", "user")
            content = message.get("content", "")
            history_text += f"{role}: {content}\n"

    # Combine history with the current prompt
    full_prompt = f"{history_text}\nUser: {request.prompt}\nAssistant:"

    # Set default generation parameters and update with any provided by the client
    gen_params = {
        "max_new_tokens": 256,
        "temperature": 0.7,
        "top_p": 0.9,
    }
    if request.parameters:
        gen_params.update(request.parameters)

    # Generate the response; note that "generated_text" includes the prompt unless
    # return_full_text=False is passed via request.parameters
    try:
        result = pipe(full_prompt, **gen_params)
        return {"response": result[0]["generated_text"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
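# Running and calling the service (a sketch; the original file does not include an entry point,
# and the module/port names below are assumptions):
#
#   uvicorn app:app --host 0.0.0.0 --port 8000      # replace "app" with this module's filename
#
#   curl -X POST http://localhost:8000/generate/ \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello", "parameters": {"max_new_tokens": 64}}'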