Spaces:
Runtime error
Runtime error
File size: 8,371 Bytes
7359499 b7e12a6 a49687b 3c13c0a 5c74d30 3c13c0a a49687b b7e12a6 a49687b 8e209ff a49687b 595cd98 b7e12a6 fd7c6e1 b7e12a6 a49687b da48f71 74cb11c 595cd98 74cb11c da48f71 a49687b b70e333 a49687b b7e12a6 a49687b 3c13c0a 595cd98 6f86aef 595cd98 6f86aef 595cd98 3c13c0a 29e0ab0 fd7c6e1 29e0ab0 08aa285 a49687b 521a01f da48f71 a49687b 8e209ff 521a01f 11df283 da48f71 a49687b 30a5f43 595cd98 3c13c0a 521a01f a49687b 0cc5b82 5c74d30 07c6902 5c74d30 74cb11c 5c74d30 8e209ff 5c74d30 11df283 5c74d30 74cb11c 5c74d30 30a5f43 11df283 5c74d30 3c13c0a 595cd98 5c74d30 595cd98 3c13c0a 595cd98 3c13c0a 595cd98 6f86aef 595cd98 6f86aef d72b0e7 595cd98 3c13c0a 0cc5b82 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 |
from fastapi import FastAPI, HTTPException, Request, File, UploadFile, Depends
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, EmailStr
from fastapi_mail import FastMail, MessageSchema, ConnectionConfig
import json
from typing import List, Optional
import os
from scripts.predictor import create_pipe, predict
# Build the model pipeline once, at import time. Both prediction endpoints pass
# this single shared instance to `predict` instead of constructing their own.
pipe = create_pipe()
# Request body for POST /predict: a single context/prompt pair plus the
# generation settings. The /predict handler forwards every field below
# unchanged to `predict`.
# (Documented with comments rather than a docstring: pydantic would copy a
# docstring into the generated JSON/OpenAPI schema.)
class PredictionRequest (BaseModel):
    # email: Optional[EmailStr] = None
    context: str  # context text passed to `predict` alongside the prompt
    prompt: str  # the prompt to generate an answer for
    # The remaining fields are generation settings handed straight to `predict`;
    # presumably they map onto the underlying pipeline's generation kwargs —
    # confirm in scripts.predictor.
    tokenize: bool = False
    add_generation_prompt: bool = True
    max_new_tokens: int = 256
    do_sample: bool = True
    temperature: float = 0.7
    top_k: int = 50
    top_p: float = 0.95
    def create_pipe(self):
        # Builds a brand-new pipeline on every call. Inside a method body the bare
        # name resolves to the module-level `create_pipe` imported from
        # scripts.predictor (class scope is not visible here), so this is not
        # recursive. The endpoints use the shared module-level `pipe` instead;
        # their calls to this method are commented out.
        return create_pipe()
# Parameters for POST /predict_batch: an uploaded JSON file of prompts plus
# generation settings that are applied to every item in the file.
# The handler injects this class with `Depends()`, so FastAPI derives the
# request parameters from this class's signature.
# NOTE(review): with Depends() the scalar settings are presumably exposed as
# query parameters and `json_file` as a multipart upload — confirm in /docs.
# (Comments instead of a docstring so the generated schema is unaffected.)
class PredictionBatchRequest (BaseModel):
    # email: Optional[EmailStr] = None
    json_file: UploadFile = File(...)  # required upload; the handler expects a JSON list of objects
    # Generation settings forwarded unchanged to `predict` for each item.
    tokenize: bool = False
    add_generation_prompt: bool = True
    max_new_tokens: int = 256
    do_sample: bool = True
    temperature: float = 0.7
    top_k: int = 50
    top_p: float = 0.95
    def create_pipe(self):
        # Builds a brand-new pipeline per call via the module-level `create_pipe`
        # (not this method, since class scope is not visible in method bodies).
        # Unused: the batch handler's call to it is commented out.
        return create_pipe()
# One generated answer. Both /predict and /predict_batch respond with a list of
# these (response_model=List[Prediction]).
class Prediction (BaseModel):
    content: str  # the text returned by `predict`
# The ASGI application. These metadata values populate the generated OpenAPI
# schema and FastAPI's interactive documentation pages.
app = FastAPI(
    version="1.0.0",
    title="Code-llama-7b-databases-finetuned2-DEMO API",
    description="Rest API for serving LLM model predictions",
)
# Configure your email server
# conf = ConnectionConfig(
# MAIL_USERNAME = os.getenv('MAIL_USERNAME'),
# MAIL_PASSWORD = os.getenv('MAIL_PASSWORD'),
# MAIL_FROM = os.getenv('MAIL_FROM'),
# MAIL_PORT = int(os.getenv('MAIL_PORT', '587')),
# MAIL_SERVER = os.getenv('MAIL_SERVER', 'smtp.gmail.com'),
# MAIL_STARTTLS = os.getenv("MAIL_STARTTLS", 'True').lower() in ('true', '1', 't'),
# MAIL_SSL_TLS = os.getenv("MAIL_SSL_TLS", 'False').lower() in ('true', '1', 't'),
# USE_CREDENTIALS = os.getenv("USE_CREDENTIALS", 'True').lower() in ('true', '1', 't'),
# VALIDATE_CERTS = os.getenv("VALIDATE_CERTS", 'True').lower() in ('true', '1', 't')
# )
# Add middleware for handling Cross-Origin Resource Sharing (CORS).
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable. It defaults to "*" (any origin) to keep the demo open; in
# production, set it to the list of trusted domains.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # Credentials (cookies, authorization headers) are only allowed for an
    # explicit origin list. Browsers reject "Access-Control-Allow-Origin: *" on
    # credentialed requests, so Starlette would instead echo back the caller's
    # Origin. That would let any website send credentialed requests. This API
    # does not read cookies or auth headers, so nothing needs credentials.
    allow_credentials="*" not in _cors_origins,
    # Any HTTP method (GET, POST, ...) and any request header are accepted.
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.middleware("http")
async def security_headers(request: Request, call_next):
    """Attach security-related HTTP headers to every response.

    Args:
        request: The incoming request, passed through untouched.
        call_next: Callback supplied by the framework that runs the rest of
            the application and returns its response.

    Returns:
        The downstream response, with the headers below set on it.
    """
    response = await call_next(request)
    # Stop browsers from MIME-sniffing responses away from their declared type.
    response.headers["X-Content-Type-Options"] = "nosniff"
    # Anti-clickjacking: only this origin and huggingface.co may frame the app
    # (presumably where the Space is embedded).
    response.headers["Content-Security-Policy"] = "frame-ancestors 'self' huggingface.co"
    # Enforce HTTPS for two years (63072000 s), subdomains included.
    response.headers["Strict-Transport-Security"] = "max-age=63072000; includeSubDomains"
    # The legacy browser XSS auditor ("1; mode=block") is deprecated and can
    # itself introduce cross-site leaks. Current guidance (MDN, OWASP) is to
    # disable it explicitly.
    response.headers["X-XSS-Protection"] = "0"
    return response
# Lightweight health endpoint. It returns a constant payload and never touches
# the model pipeline, so it reflects only that the process is serving requests.
@app.get("/heartbeat")
async def heartbeat():
    healthy = {"status": "healthy"}
    return healthy
# Single-prompt generation endpoint. Every setting on the request body is
# forwarded unchanged to `predict` together with the shared module-level `pipe`.
# The generated text is wrapped in a one-element List[Prediction], the same
# response shape /predict_batch uses. Any failure becomes HTTP 500 carrying the
# exception text.
# NOTE(review): `predict` is called synchronously inside this async handler, so
# the event loop is blocked while generation runs. Consider
# fastapi.concurrency.run_in_threadpool if concurrent requests matter.
@app.post("/predict", response_model=List[Prediction], status_code=200)
async def make_prediction(request: PredictionRequest):
    try:
        generated = predict(
            context=request.context,
            prompt=request.prompt,
            pipe=pipe,
            tokenize=request.tokenize,
            add_generation_prompt=request.add_generation_prompt,
            max_new_tokens=request.max_new_tokens,
            do_sample=request.do_sample,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
        )
        return [Prediction(content=generated)]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
def _load_batch_items(raw):
    """Parse an uploaded batch file into a list of prediction items.

    Args:
        raw: The raw content of the uploaded file (bytes or str).

    Returns:
        The decoded list. Every element is a dict that contains a 'prompt' key.

    Raises:
        ValueError: With a client-facing message when the payload is not valid
            JSON, is not a JSON list, or contains an element that is not an
            object with a 'prompt' field.
    """
    try:
        data = json.loads(raw)
    except ValueError:
        # Catches json.JSONDecodeError and also UnicodeDecodeError, which
        # json.loads raises for bytes that are not valid UTF-8/16/32. Both are
        # ValueError subclasses.
        raise ValueError("Invalid JSON file.") from None
    if not isinstance(data, list):
        raise ValueError("Invalid JSON format. Expected a list of JSON objects.")
    # Validate every item before any generation runs, so a malformed entry is
    # reported up front instead of after earlier items were already generated.
    for item in data:
        if not isinstance(item, dict) or 'prompt' not in item:
            raise ValueError("Each JSON object must contain at least a 'prompt' field.")
    return data


# Batch generation endpoint. The uploaded JSON file must hold a list of objects,
# each with a required 'prompt' and an optional 'context'. The response has one
# Prediction per object, in input order. The generation settings on
# PredictionBatchRequest apply to every item.
#
# Error mapping: a malformed upload gives 400 with a descriptive message; any
# failure during generation gives 500 with the exception text.
# NOTE(review): `predict` runs synchronously inside this async handler, so the
# event loop is blocked for the whole batch. Consider
# fastapi.concurrency.run_in_threadpool if concurrent requests matter.
@app.post("/predict_batch", response_model=List[Prediction], status_code=200)
async def make_batch_prediction(request: PredictionBatchRequest = Depends()):
    try:
        if not request.json_file:
            raise HTTPException(status_code=400, detail="No JSON file provided.")
        content = await request.json_file.read()
        try:
            items = _load_batch_items(content)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e)) from e
        # pipe = request.create_pipe()
        predictions = []
        for item in items:
            prediction = predict(
                context=item.get('context', 'Provide an answer to the following question:'),
                prompt=item['prompt'],
                pipe=pipe,
                tokenize=request.tokenize,
                add_generation_prompt=request.add_generation_prompt,
                max_new_tokens=request.max_new_tokens,
                do_sample=request.do_sample,
                temperature=request.temperature,
                top_k=request.top_k,
                top_p=request.top_p,
            )
            predictions.append(Prediction(content=prediction))
        return predictions
    except HTTPException:
        # Re-raise deliberate client errors unchanged. HTTPException subclasses
        # Exception, so without this clause the handler below would turn every
        # 400 raised above into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# # Function to send email
# async def send_email(email: str, content: List[dict]):
# # Construct the email body by iterating through the list of content objects
# email_body = "<h1>Your AI Generated Answers</h1>"
# for item in content:
# instruction = item.get('instruction', 'Provide an answer to the following question:')
# input_text = item['input']
# output_text = item['output']
# email_body += f"""
# <h2>Instruction:</h2>
# <p>{instruction}</p>
# <h2>Input:</h2>
# <p>{input_text}</p>
# <h2>Output:</h2>
# <p>{output_text}</p>
# <hr>
# """
# message = MessageSchema(
# subject="Your AI Generated Answers",
# recipients=[email],
# html=email_body,
# subtype="html"
# )
# fm = FastMail(conf)
# await fm.send_message(message)
# # Ensure your email configuration works
# @app.get("/test-email")
# async def test_email():
# try:
# await send_email(os.getenv('TEST_EMAIL'), [{
# "instruction": "This is a test instruction.",
# "input": "This is a test input.",
# "output": "This is a test output.",
# }])
#
# return {"message": "Test email sent successfully"}
# except Exception as e:
# raise HTTPException(status_code=500, detail=str(e))
# Front-end serving. The explicit "/" route is registered BEFORE the static
# mount. Starlette matches routes in registration order, and a mount at "/"
# captures every path, so a route added after it is never reached. Both use
# the same directory (resolved against the working directory), so "/" serves
# the same index.html the mount would.
_STATIC_DIR = "static"


@app.get("/")
def index() -> FileResponse:
    return FileResponse(path=os.path.join(_STATIC_DIR, "index.html"), media_type="text/html")


# Catch-all for the remaining static assets; html=True also serves index.html
# for directory paths.
app.mount("/", StaticFiles(directory=_STATIC_DIR, html=True), name="static")