# NOTE(review): the three lines that preceded this file's code were non-code
# page residue from a web scrape ("Spaces:" / "Runtime error" x2); they have
# been converted to this comment so the module parses.
from fastapi import FastAPI, HTTPException, Request, File, UploadFile, Depends
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, EmailStr
from fastapi_mail import FastMail, MessageSchema, ConnectionConfig
import json
from typing import List, Optional
import os
from scripts.predictor import create_pipe, predict
pipe = create_pipe() | |
class PredictionRequest (BaseModel): | |
# email: Optional[EmailStr] = None | |
context: str | |
prompt: str | |
tokenize: bool = False | |
add_generation_prompt: bool = True | |
max_new_tokens: int = 256 | |
do_sample: bool = True | |
temperature: float = 0.7 | |
top_k: int = 50 | |
top_p: float = 0.95 | |
def create_pipe(self): | |
return create_pipe() | |
class PredictionBatchRequest (BaseModel): | |
# email: Optional[EmailStr] = None | |
json_file: UploadFile = File(...) | |
tokenize: bool = False | |
add_generation_prompt: bool = True | |
max_new_tokens: int = 256 | |
do_sample: bool = True | |
temperature: float = 0.7 | |
top_k: int = 50 | |
top_p: float = 0.95 | |
def create_pipe(self): | |
return create_pipe() | |
class Prediction (BaseModel): | |
content: str | |
app = FastAPI( | |
title="Code-llama-7b-databases-finetuned2-DEMO API", | |
description="Rest API for serving LLM model predictions", | |
version="1.0.0", | |
) | |
# Configure your email server | |
# conf = ConnectionConfig( | |
# MAIL_USERNAME = os.getenv('MAIL_USERNAME'), | |
# MAIL_PASSWORD = os.getenv('MAIL_PASSWORD'), | |
# MAIL_FROM = os.getenv('MAIL_FROM'), | |
# MAIL_PORT = int(os.getenv('MAIL_PORT', '587')), | |
# MAIL_SERVER = os.getenv('MAIL_SERVER', 'smtp.gmail.com'), | |
# MAIL_STARTTLS = os.getenv("MAIL_STARTTLS", 'True').lower() in ('true', '1', 't'), | |
# MAIL_SSL_TLS = os.getenv("MAIL_SSL_TLS", 'False').lower() in ('true', '1', 't'), | |
# USE_CREDENTIALS = os.getenv("USE_CREDENTIALS", 'True').lower() in ('true', '1', 't'), | |
# VALIDATE_CERTS = os.getenv("VALIDATE_CERTS", 'True').lower() in ('true', '1', 't') | |
# ) | |
# Add middleware for handling Cross-Origin Resource Sharing (CORS) | |
app.add_middleware( | |
CORSMiddleware, | |
# allow_origins specifies which origins are allowed to access the resource. | |
# "*" means any origin is allowed. In production, replace this with a list of trusted domains. | |
allow_origins=["*"], | |
# allow_credentials specifies whether the browser should include credentials (cookies, authorization headers, etc.) | |
# with requests. Set to True to allow credentials to be sent. | |
allow_credentials=True, | |
# allow_methods specifies which HTTP methods are allowed when accessing the resource. | |
# "*" means all HTTP methods (GET, POST, PUT, DELETE, etc.) are allowed. | |
allow_methods=["*"], | |
# allow_headers specifies which HTTP headers can be used when making the actual request. | |
# "*" means all headers are allowed. | |
allow_headers=["*"], | |
) | |
async def security_headers(request: Request, call_next): | |
response = await call_next(request) # Process the request and get the response | |
response.headers["X-Content-Type-Options"] = "nosniff" # Prevent MIME type sniffing | |
response.headers["Content-Security-Policy"] = "frame-ancestors 'self' huggingface.co" # Prevent clickjacking | |
response.headers["Strict-Transport-Security"] = "max-age=63072000; includeSubDomains" # Enforce HTTPS | |
response.headers["X-XSS-Protection"] = "1; mode=block" # Enable XSS filter in browsers | |
return response # Return the response with the added security headers | |
async def heartbeat(): | |
return {"status": "healthy"} | |
async def make_prediction(request: PredictionRequest): | |
try: | |
# pipe = request.create_pipe() | |
predictions = [] | |
prediction = predict( | |
context=request.context, | |
prompt=request.prompt, | |
pipe=pipe, | |
tokenize=request.tokenize, | |
add_generation_prompt=request.add_generation_prompt, | |
max_new_tokens=request.max_new_tokens, | |
do_sample=request.do_sample, | |
temperature=request.temperature, | |
top_k=request.top_k, | |
top_p=request.top_p | |
) | |
# # If the user provided an email, send the prediction result via email | |
# if request.email: | |
# await send_email(request.email, content) | |
predictions.append(Prediction(content=prediction)) | |
return predictions | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=str(e)) | |
async def make_batch_prediction(request: PredictionBatchRequest = Depends()): | |
try: | |
if not request.json_file: | |
raise HTTPException(status_code=400, detail="No JSON file provided.") | |
content = await request.json_file.read() | |
data = json.loads(content) | |
if not isinstance(data, list): | |
raise HTTPException(status_code=400, detail="Invalid JSON format. Expected a list of JSON objects.") | |
# pipe = request.create_pipe() | |
predictions = [] | |
for item in data: | |
try: | |
context = item.get('context', 'Provide an answer to the following question:') | |
prompt = item['prompt'] | |
prediction = predict( | |
context=context, | |
prompt=prompt, | |
pipe=pipe, | |
tokenize=request.tokenize, | |
add_generation_prompt=request.add_generation_prompt, | |
max_new_tokens=request.max_new_tokens, | |
do_sample=request.do_sample, | |
temperature=request.temperature, | |
top_k=request.top_k, | |
top_p=request.top_p | |
) | |
predictions.append(Prediction(content=prediction)) | |
except KeyError: | |
raise HTTPException(status_code=400, detail="Each JSON object must contain at least a 'prompt' field.") | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=str(e)) | |
# # If the user provided an email, send the prediction result via email | |
# if request.email: | |
# await send_email(request.email, content) | |
return predictions | |
except json.JSONDecodeError: | |
raise HTTPException(status_code=400, detail="Invalid JSON file.") | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=str(e)) | |
# # Function to send email | |
# async def send_email(email: str, content: List[dict]): | |
# # Construct the email body by iterating through the list of content objects | |
# email_body = "<h1>Your AI Generated Answers</h1>" | |
# for item in content: | |
# instruction = item.get('instruction', 'Provide an answer to the following question:') | |
# input_text = item['input'] | |
# output_text = item['output'] | |
# email_body += f""" | |
# <h2>Instruction:</h2> | |
# <p>{instruction}</p> | |
# <h2>Input:</h2> | |
# <p>{input_text}</p> | |
# <h2>Output:</h2> | |
# <p>{output_text}</p> | |
# <hr> | |
# """ | |
# message = MessageSchema( | |
# subject="Your AI Generated Answers", | |
# recipients=[email], | |
# html=email_body, | |
# subtype="html" | |
# ) | |
# fm = FastMail(conf) | |
# await fm.send_message(message) | |
# # Ensure your email configuration works | |
# @app.get("/test-email") | |
# async def test_email(): | |
# try: | |
# await send_email(os.getenv('TEST_EMAIL'), [{ | |
# "instruction": "This is a test instruction.", | |
# "input": "This is a test input.", | |
# "output": "This is a test output.", | |
# }]) | |
# | |
# return {"message": "Test email sent successfully"} | |
# except Exception as e: | |
# raise HTTPException(status_code=500, detail=str(e)) | |
app.mount("/", StaticFiles(directory="static", html=True), name="static") | |
def index() -> FileResponse: | |
return FileResponse(path="/app/static/index.html", media_type="text/html") |