"""FastAPI service exposing the defog/llama-3-sqlcoder-8b model via a
text-generation pipeline.

Endpoints:
    POST /generate  -- run text generation on a chat-style prompt.
    GET  /health    -- liveness probe.
"""

import logging
import os

# The HF cache location must be set BEFORE transformers is imported:
# the library reads TRANSFORMERS_CACHE at import time, so assigning it
# after `from transformers import pipeline` (as the original code did)
# has no effect.
os.environ.setdefault('TRANSFORMERS_CACHE', '/home/user/.cache/huggingface')

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="SQL Coder API")

# Initialize pipeline once at startup; model loading is expensive, so a
# failure here should abort the process rather than surface per-request.
try:
    pipe = pipeline(
        "text-generation",
        model="defog/llama-3-sqlcoder-8b",
        device_map="auto",
        model_kwargs={"torch_dtype": "auto"}
    )
    logger.info("Pipeline initialized successfully")
except Exception:
    logger.exception("Error initializing pipeline")
    raise


class ChatMessage(BaseModel):
    # One turn of a chat conversation (e.g. role="user"/"assistant").
    role: str
    content: str


class QueryRequest(BaseModel):
    # Chat history to condition generation on, plus sampling knobs.
    messages: list[ChatMessage]
    max_new_tokens: int = 1024
    temperature: float = 0.7


class QueryResponse(BaseModel):
    # Full pipeline output (prompt + continuation, the pipeline default).
    generated_text: str


@app.post("/generate", response_model=QueryResponse)
def generate(request: QueryRequest) -> QueryResponse:
    """Generate text for the given chat messages.

    Declared as a plain ``def`` (not ``async def``) on purpose: the
    pipeline call is CPU/GPU-bound and blocking, so FastAPI must run it
    in its threadpool instead of stalling the event loop.

    Raises:
        HTTPException: 500 with the underlying error message on failure.
    """
    try:
        # Flatten the chat history into a single "role: content" prompt.
        formatted_prompt = "\n".join(
            f"{msg.role}: {msg.content}" for msg in request.messages
        )

        # Blocking inference call.
        response = pipe(
            formatted_prompt,
            max_new_tokens=request.max_new_tokens,
            temperature=request.temperature,
            do_sample=True,
            num_return_sequences=1,
            # Suppress the "pad_token_id not set" warning for llama models.
            pad_token_id=pipe.tokenizer.eos_token_id
        )

        # Pipeline returns a list of dicts; we requested a single sequence.
        generated_text = response[0]['generated_text']

        return QueryResponse(generated_text=generated_text)
    except Exception as e:
        logger.exception("Error generating response")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health_check():
    """Liveness probe: returns a static healthy status."""
    return {"status": "healthy"}