# FastAPI service wrapping the defog/llama-3-sqlcoder-8b text-generation pipeline.
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import logging
import os

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The cache directory must be set BEFORE importing transformers: the library
# reads TRANSFORMERS_CACHE at import time, so assigning it after the import
# (as the original code did) has no effect on where models are cached.
os.environ['TRANSFORMERS_CACHE'] = '/home/user/.cache/huggingface'

from transformers import pipeline  # noqa: E402 — deliberately after env setup

app = FastAPI(title="SQL Coder API")

# Initialize the text-generation pipeline once at startup; a failure here is
# fatal (re-raised) so the service does not start without a working model.
try:
    pipe = pipeline(
        "text-generation",
        model="defog/llama-3-sqlcoder-8b",
        device_map="auto",
        model_kwargs={"torch_dtype": "auto"},
    )
    logger.info("Pipeline initialized successfully")
except Exception as e:
    logger.error(f"Error initializing pipeline: {str(e)}")
    raise
class ChatMessage(BaseModel):
    """One chat turn: the speaker's role and the message text."""

    role: str     # e.g. "user" / "assistant" — presumably; TODO confirm expected roles
    content: str  # the message body for this turn
class QueryRequest(BaseModel):
    """Request body for generation: a chat history plus sampling controls."""

    messages: list[ChatMessage]  # conversation turns, presumably oldest first
    max_new_tokens: int = 1024   # cap on generated tokens
    temperature: float = 0.7     # sampling temperature passed to the pipeline
class QueryResponse(BaseModel):
    """Response body: the raw text produced by the pipeline."""

    generated_text: str  # full pipeline output (includes the prompt prefix)
# NOTE(review): no route decorator was visible in the mangled source, so this
# endpoint was never registered on `app`. Restored as POST /generate — confirm
# the path against existing clients.
@app.post("/generate", response_model=QueryResponse)
async def generate(request: QueryRequest) -> QueryResponse:
    """Generate text from the supplied chat history using the model pipeline.

    Raises:
        HTTPException: 500 with the underlying error message when generation fails.
    """
    try:
        # Flatten the chat turns into a single "role: content" prompt string.
        formatted_prompt = "\n".join(
            f"{msg.role}: {msg.content}" for msg in request.messages
        )
        response = pipe(
            formatted_prompt,
            max_new_tokens=request.max_new_tokens,
            temperature=request.temperature,
            do_sample=True,
            num_return_sequences=1,
            # LLaMA-style tokenizers have no pad token; use EOS to silence the warning.
            pad_token_id=pipe.tokenizer.eos_token_id,
        )
        # The pipeline returns a list of dicts; take the single requested sequence.
        generated_text = response[0]['generated_text']
        return QueryResponse(generated_text=generated_text)
    except Exception as e:
        logger.error(f"Error generating response: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
# NOTE(review): route decorator was missing in the mangled source; restored as
# GET /health — confirm the path against the deployment's probe configuration.
@app.get("/health")
async def health_check():
    """Liveness probe: always reports a static healthy status."""
    return {"status": "healthy"}