mistrial-props / app.py
novamysticX's picture
Update app.py
fca0532 verified
raw
history blame
1.95 kB
# Standard library imports first
import logging
import os

# Point the Hugging Face cache at a writable location BEFORE importing
# transformers: the library reads TRANSFORMERS_CACHE at import time, so
# setting it after `from transformers import pipeline` (as the original
# code did) has no effect on where model files are downloaded.
os.environ['TRANSFORMERS_CACHE'] = '/home/user/.cache/huggingface'

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline

# Configure module-level logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="SQL Coder API")
# Build the text-generation pipeline once at module load; fail fast (re-raise)
# if the model cannot be loaded so the service never starts in a broken state.
try:
    pipe = pipeline(
        "text-generation",
        model="defog/llama-3-sqlcoder-8b",
        device_map="auto",                      # spread across available devices
        model_kwargs={"torch_dtype": "auto"},   # let transformers pick the dtype
    )
except Exception as exc:
    logger.error(f"Error initializing pipeline: {str(exc)}")
    raise
else:
    logger.info("Pipeline initialized successfully")
class ChatMessage(BaseModel):
    """A single message in the chat history sent to /generate."""
    # Speaker role string (e.g. "user" / "assistant"); not validated here
    role: str
    # Text content of the message
    content: str
class QueryRequest(BaseModel):
    """Request body for the /generate endpoint."""
    # Ordered conversation history; flattened into one prompt string
    messages: list[ChatMessage]
    # Maximum number of new tokens to generate beyond the prompt
    max_new_tokens: int = 1024
    # Sampling temperature; higher values produce more varied output
    temperature: float = 0.7
class QueryResponse(BaseModel):
    """Response body for the /generate endpoint."""
    # Raw text returned by the generation pipeline
    generated_text: str
@app.post("/generate", response_model=QueryResponse)
def generate(request: QueryRequest):
    """Generate model output for the chat history in *request*.

    Declared as a plain (sync) ``def`` on purpose: ``pipe(...)`` is a
    blocking, compute-heavy call, and FastAPI executes sync endpoints in a
    worker threadpool. The original ``async def`` ran the blocking call
    directly on the asyncio event loop, freezing every other request
    (including /health) for the duration of generation. The HTTP interface
    is unchanged.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        # Flatten the structured chat history into a single plain-text prompt
        formatted_prompt = "\n".join(
            f"{msg.role}: {msg.content}" for msg in request.messages
        )
        # Run generation. pad_token_id reuses the EOS token, the usual choice
        # for Llama-style tokenizers that ship without a pad token.
        response = pipe(
            formatted_prompt,
            max_new_tokens=request.max_new_tokens,
            temperature=request.temperature,
            do_sample=True,
            num_return_sequences=1,
            pad_token_id=pipe.tokenizer.eos_token_id,
        )
        # The pipeline returns a list with one dict per returned sequence;
        # we requested exactly one.
        generated_text = response[0]['generated_text']
        return QueryResponse(generated_text=generated_text)
    except Exception as e:
        logger.error(f"Error generating response: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
    """Liveness probe: always reports the service as healthy."""
    payload = {"status": "healthy"}
    return payload