from fastapi import FastAPI, HTTPException, Request, File, UploadFile, Depends
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, EmailStr
from fastapi_mail import FastMail, MessageSchema, ConnectionConfig
import json
from typing import List, Optional
import os

from scripts.predictor import create_pipe, predict

# Build the inference pipeline once at startup so every request reuses it.
pipe = create_pipe()


class PredictionRequest(BaseModel):
    # email: Optional[EmailStr] = None
    context: str
    prompt: str
    tokenize: bool = False
    add_generation_prompt: bool = True
    max_new_tokens: int = 256
    do_sample: bool = True
    temperature: float = 0.7
    top_k: int = 50
    top_p: float = 0.95

    def create_pipe(self):
        return create_pipe()


class PredictionBatchRequest(BaseModel):
    # email: Optional[EmailStr] = None
    json_file: UploadFile = File(...)
    tokenize: bool = False
    add_generation_prompt: bool = True
    max_new_tokens: int = 256
    do_sample: bool = True
    temperature: float = 0.7
    top_k: int = 50
    top_p: float = 0.95

    def create_pipe(self):
        return create_pipe()


class Prediction(BaseModel):
    content: str


app = FastAPI(
    title="Code-llama-7b-databases-finetuned2-DEMO API",
    description="REST API for serving LLM model predictions",
    version="1.0.0",
)

# Configure your email server
# conf = ConnectionConfig(
#     MAIL_USERNAME = os.getenv('MAIL_USERNAME'),
#     MAIL_PASSWORD = os.getenv('MAIL_PASSWORD'),
#     MAIL_FROM = os.getenv('MAIL_FROM'),
#     MAIL_PORT = int(os.getenv('MAIL_PORT', '587')),
#     MAIL_SERVER = os.getenv('MAIL_SERVER', 'smtp.gmail.com'),
#     MAIL_STARTTLS = os.getenv("MAIL_STARTTLS", 'True').lower() in ('true', '1', 't'),
#     MAIL_SSL_TLS = os.getenv("MAIL_SSL_TLS", 'False').lower() in ('true', '1', 't'),
#     USE_CREDENTIALS = os.getenv("USE_CREDENTIALS", 'True').lower() in ('true', '1', 't'),
#     VALIDATE_CERTS = os.getenv("VALIDATE_CERTS", 'True').lower() in ('true', '1', 't')
# )

# Add middleware for handling Cross-Origin Resource Sharing (CORS)
app.add_middleware(
    CORSMiddleware,
    # allow_origins specifies which origins may access the resource.
    # "*" means any origin is allowed. In production, replace this with a list of trusted domains.
    allow_origins=["*"],
    # allow_credentials specifies whether the browser should include credentials
    # (cookies, authorization headers, etc.) with requests. Set to True to allow credentials.
    allow_credentials=True,
    # allow_methods specifies which HTTP methods are allowed when accessing the resource.
    # "*" means all HTTP methods (GET, POST, PUT, DELETE, etc.) are allowed.
    allow_methods=["*"],
    # allow_headers specifies which HTTP headers can be used when making the actual request.
    # "*" means all headers are allowed.
    allow_headers=["*"],
)
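# As a sketch of the locked-down configuration the comment above suggests,
# a production deployment might instead pass explicit values like the
# following (the domains, methods, and headers here are placeholders, not
# values taken from this project):
#
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["https://example.com", "https://demo.example.com"],
#     allow_credentials=True,
#     allow_methods=["GET", "POST"],
#     allow_headers=["Content-Type", "Authorization"],
# )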
allow_headers=["*"], ) @app.middleware("http") async def security_headers(request: Request, call_next): response = await call_next(request) # Process the request and get the response response.headers["X-Content-Type-Options"] = "nosniff" # Prevent MIME type sniffing response.headers["Content-Security-Policy"] = "frame-ancestors 'self' huggingface.co" # Prevent clickjacking response.headers["Strict-Transport-Security"] = "max-age=63072000; includeSubDomains" # Enforce HTTPS response.headers["X-XSS-Protection"] = "1; mode=block" # Enable XSS filter in browsers return response # Return the response with the added security headers @app.get("/heartbeat") async def heartbeat(): return {"status": "healthy"} @app.post("/predict", response_model=List[Prediction], status_code=200) async def make_prediction(request: PredictionRequest): try: # pipe = request.create_pipe() predictions = [] prediction = predict( context=request.context, prompt=request.prompt, pipe=pipe, tokenize=request.tokenize, add_generation_prompt=request.add_generation_prompt, max_new_tokens=request.max_new_tokens, do_sample=request.do_sample, temperature=request.temperature, top_k=request.top_k, top_p=request.top_p ) # # If the user provided an email, send the prediction result via email # if request.email: # await send_email(request.email, content) predictions.append(Prediction(content=prediction)) return predictions except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @app.post("/predict_batch", response_model=List[Prediction], status_code=200) async def make_batch_prediction(request: PredictionBatchRequest = Depends()): try: if not request.json_file: raise HTTPException(status_code=400, detail="No JSON file provided.") content = await request.json_file.read() data = json.loads(content) if not isinstance(data, list): raise HTTPException(status_code=400, detail="Invalid JSON format. Expected a list of JSON objects.") # pipe = request.create_pipe() predictions = [] for item in data: try: context = item.get('context', 'Provide an answer to the following question:') prompt = item['prompt'] prediction = predict( context=context, prompt=prompt, pipe=pipe, tokenize=request.tokenize, add_generation_prompt=request.add_generation_prompt, max_new_tokens=request.max_new_tokens, do_sample=request.do_sample, temperature=request.temperature, top_k=request.top_k, top_p=request.top_p ) predictions.append(Prediction(content=prediction)) except KeyError: raise HTTPException(status_code=400, detail="Each JSON object must contain at least a 'prompt' field.") except Exception as e: raise HTTPException(status_code=500, detail=str(e)) # # If the user provided an email, send the prediction result via email # if request.email: # await send_email(request.email, content) return predictions except json.JSONDecodeError: raise HTTPException(status_code=400, detail="Invalid JSON file.") except Exception as e: raise HTTPException(status_code=500, detail=str(e)) # # Function to send email # async def send_email(email: str, content: List[dict]): # # Construct the email body by iterating through the list of content objects # email_body = "

# # Function to send email
# async def send_email(email: str, content: List[dict]):
#     # Construct the email body by iterating through the list of content objects
#     email_body = "<h1>Your AI Generated Answers</h1>"
#     for item in content:
#         instruction = item.get('instruction', 'Provide an answer to the following question:')
#         input_text = item['input']
#         output_text = item['output']
#         email_body += f"""
#         <h2>Instruction:</h2>
#         <p>{instruction}</p>
#         <h2>Input:</h2>
#         <p>{input_text}</p>
#         <h2>Output:</h2>
#         <p>{output_text}</p>
#         """
#     message = MessageSchema(
#         subject="Your AI Generated Answers",
#         recipients=[email],
#         html=email_body,
#         subtype="html"
#     )
#     fm = FastMail(conf)
#     await fm.send_message(message)

# # Ensure your email configuration works
# @app.get("/test-email")
# async def test_email():
#     try:
#         await send_email(os.getenv('TEST_EMAIL'), [{
#             "instruction": "This is a test instruction.",
#             "input": "This is a test input.",
#             "output": "This is a test output.",
#         }])
#         return {"message": "Test email sent successfully"}
#     except Exception as e:
#         raise HTTPException(status_code=500, detail=str(e))


@app.get("/")
def index() -> FileResponse:
    return FileResponse(path="/app/static/index.html", media_type="text/html")


# Mount the static files last so the explicit routes above take precedence;
# mounting at "/" before defining the index route would shadow it.
app.mount("/", StaticFiles(directory="static", html=True), name="static")