File size: 8,371 Bytes
7359499
b7e12a6
 
a49687b
3c13c0a
 
5c74d30
3c13c0a
 
a49687b
b7e12a6
a49687b
8e209ff
 
a49687b
595cd98
b7e12a6
 
 
 
 
 
 
fd7c6e1
b7e12a6
a49687b
da48f71
 
74cb11c
 
595cd98
74cb11c
 
 
 
 
 
 
 
 
 
 
da48f71
a49687b
 
 
 
b70e333
a49687b
b7e12a6
a49687b
 
3c13c0a
595cd98
 
 
 
6f86aef
595cd98
6f86aef
 
595cd98
 
 
3c13c0a
29e0ab0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd7c6e1
29e0ab0
 
 
 
 
08aa285
a49687b
 
 
521a01f
da48f71
a49687b
8e209ff
521a01f
 
 
11df283
da48f71
 
 
 
 
 
 
 
 
 
a49687b
30a5f43
595cd98
 
 
3c13c0a
521a01f
 
 
a49687b
0cc5b82
 
5c74d30
07c6902
5c74d30
74cb11c
 
 
 
5c74d30
 
 
 
 
8e209ff
5c74d30
 
 
 
11df283
 
 
 
 
 
5c74d30
74cb11c
 
 
 
 
 
 
5c74d30
30a5f43
11df283
5c74d30
 
 
 
3c13c0a
595cd98
 
 
5c74d30
 
 
 
 
 
 
 
595cd98
 
 
 
 
 
 
 
3c13c0a
595cd98
 
 
 
 
 
 
 
 
3c13c0a
595cd98
 
 
6f86aef
595cd98
 
 
 
 
 
 
 
 
 
 
6f86aef
d72b0e7
 
 
 
 
595cd98
 
 
3c13c0a
0cc5b82
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
from fastapi import FastAPI, HTTPException, Request, File, UploadFile, Depends
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, EmailStr
from fastapi_mail import FastMail, MessageSchema, ConnectionConfig
import json
from typing import List, Optional
import os

from scripts.predictor import create_pipe, predict

# Build the inference pipeline once, at import time. The resulting module-level
# `pipe` is the object both prediction endpoints pass to predict().
pipe = create_pipe()

# Request body accepted by POST /predict.
# (Documented with comments rather than a class docstring so the generated
# schema description stays exactly as it was.)
class PredictionRequest(BaseModel):
    # Optional e-mail delivery is currently disabled:
    # email: Optional[EmailStr] = None

    # Text inputs, forwarded to predict() as `context` and `prompt`.
    context: str
    prompt: str

    # Generation settings, forwarded to predict() under the same names.
    tokenize: bool = False
    add_generation_prompt: bool = True
    max_new_tokens: int = 256
    do_sample: bool = True
    temperature: float = 0.7
    top_k: int = 50
    top_p: float = 0.95

    def create_pipe(self):
        """Return a pipeline built by the module-level ``create_pipe`` function.

        Inside a method body the bare name resolves to the imported function,
        not to this method, so there is no recursion.
        """
        fresh_pipe = create_pipe()
        return fresh_pipe
    
# Parameters accepted by POST /predict_batch (resolved there via Depends()).
# (Documented with comments rather than a class docstring so the generated
# schema description stays exactly as it was.)
class PredictionBatchRequest(BaseModel):
    # Optional e-mail delivery is currently disabled:
    # email: Optional[EmailStr] = None

    # Uploaded file; the endpoint reads it and parses the content as JSON.
    json_file: UploadFile = File(...)

    # Generation settings, applied to every item of the batch.
    tokenize: bool = False
    add_generation_prompt: bool = True
    max_new_tokens: int = 256
    do_sample: bool = True
    temperature: float = 0.7
    top_k: int = 50
    top_p: float = 0.95

    def create_pipe(self):
        """Return a pipeline built by the module-level ``create_pipe`` function.

        Inside a method body the bare name resolves to the imported function,
        not to this method, so there is no recursion.
        """
        fresh_pipe = create_pipe()
        return fresh_pipe

# One element of the list returned by the prediction endpoints.
class Prediction(BaseModel):
    # The value returned by predict() for a single prompt.
    content: str

# The ASGI application object. Every middleware, route and static mount below
# is registered on this instance; title/description/version are API metadata.
app = FastAPI(
    title="Code-llama-7b-databases-finetuned2-DEMO API",
    description="Rest API for serving LLM model predictions",
    version="1.0.0",
)

# Configure your email server
# conf = ConnectionConfig(
#     MAIL_USERNAME = os.getenv('MAIL_USERNAME'),
#     MAIL_PASSWORD = os.getenv('MAIL_PASSWORD'),
#     MAIL_FROM = os.getenv('MAIL_FROM'),
#     MAIL_PORT = int(os.getenv('MAIL_PORT', '587')),
#     MAIL_SERVER = os.getenv('MAIL_SERVER', 'smtp.gmail.com'),
#     MAIL_STARTTLS = os.getenv("MAIL_STARTTLS", 'True').lower() in ('true', '1', 't'),
#     MAIL_SSL_TLS = os.getenv("MAIL_SSL_TLS", 'False').lower() in ('true', '1', 't'),
#     USE_CREDENTIALS = os.getenv("USE_CREDENTIALS", 'True').lower() in ('true', '1', 't'),
#     VALIDATE_CERTS = os.getenv("VALIDATE_CERTS", 'True').lower() in ('true', '1', 't')
# )

# Add middleware for handling Cross-Origin Resource Sharing (CORS)
# NOTE(review): the configuration below combines a wildcard origin with
# allow_credentials=True. The FastAPI CORS documentation says origins, methods
# and headers should be listed explicitly when credentials are allowed, so this
# pairing presumably needs tightening before production use — confirm which
# origins are trusted and whether credentialed cross-origin requests are needed.
app.add_middleware(
    CORSMiddleware,
    # allow_origins specifies which origins are allowed to access the resource.
    # "*" means any origin is allowed. In production, replace this with a list of trusted domains.
    allow_origins=["*"],
    # allow_credentials specifies whether the browser should include credentials (cookies, authorization headers, etc.)
    # with requests. Set to True to allow credentials to be sent.
    allow_credentials=True,
    # allow_methods specifies which HTTP methods are allowed when accessing the resource.
    # "*" means all HTTP methods (GET, POST, PUT, DELETE, etc.) are allowed.
    allow_methods=["*"],
    # allow_headers specifies which HTTP headers can be used when making the actual request.
    # "*" means all headers are allowed.
    allow_headers=["*"],
)

@app.middleware("http")
async def security_headers(request: Request, call_next):
    """Add a fixed set of security-related headers to every HTTP response.

    Args:
        request: The incoming request, handed on unchanged to ``call_next``.
        call_next: Awaitable callable that produces the downstream response.

    Returns:
        The downstream response, with the headers listed below set
        (overwriting any value already present under the same name).
    """
    hardening_headers = (
        # Prevent MIME type sniffing
        ("X-Content-Type-Options", "nosniff"),
        # Restrict which pages may embed this app in a frame (clickjacking)
        ("Content-Security-Policy", "frame-ancestors 'self' huggingface.co"),
        # Enforce HTTPS
        ("Strict-Transport-Security", "max-age=63072000; includeSubDomains"),
        # Enable the XSS filter in browsers
        ("X-XSS-Protection", "1; mode=block"),
    )

    outgoing = await call_next(request)
    for header_name, header_value in hardening_headers:
        outgoing.headers[header_name] = header_value
    return outgoing

# GET /heartbeat — liveness probe; unconditionally reports a healthy status.
@app.get("/heartbeat")
async def heartbeat():
    status_payload = {"status": "healthy"}
    return status_payload

# POST /predict — run the shared pipeline on a single context/prompt pair.
#
# The response is a one-element list holding the Prediction for the request.
# Any exception raised while predicting or building the response is reported
# as HTTP 500, with the exception text as the detail.
#
# E-mail delivery of the result is currently disabled (see the commented-out
# send_email helper further down in this file).
@app.post("/predict", response_model=List[Prediction], status_code=200)
async def make_prediction(request: PredictionRequest):
    try:
        # Gather every argument for predict() in one place; the generation
        # settings are taken verbatim from the request body.
        call_arguments = dict(
            context=request.context,
            prompt=request.prompt,
            pipe=pipe,
            tokenize=request.tokenize,
            add_generation_prompt=request.add_generation_prompt,
            max_new_tokens=request.max_new_tokens,
            do_sample=request.do_sample,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
        )
        generated = predict(**call_arguments)

        return [Prediction(content=generated)]
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    
def _parse_batch_items(data):
    """Validate a decoded batch payload and return ``(context, prompt)`` pairs.

    Args:
        data: The value obtained by JSON-decoding the uploaded file.

    Returns:
        A list of ``(context, prompt)`` tuples, one per input object and in
        the same order. A missing 'context' falls back to the default
        instruction text.

    Raises:
        HTTPException: 400 if ``data`` is not a list, or if any element is not
            a JSON object containing a 'prompt' field.
    """
    if not isinstance(data, list):
        raise HTTPException(status_code=400, detail="Invalid JSON format. Expected a list of JSON objects.")

    pairs = []
    for item in data:
        # Reject scalars / nested lists explicitly instead of failing later
        # with an AttributeError on `.get`.
        if not isinstance(item, dict) or 'prompt' not in item:
            raise HTTPException(status_code=400, detail="Each JSON object must contain at least a 'prompt' field.")
        context = item.get('context', 'Provide an answer to the following question:')
        pairs.append((context, item['prompt']))
    return pairs


@app.post("/predict_batch", response_model=List[Prediction], status_code=200)
async def make_batch_prediction(request: PredictionBatchRequest = Depends()):
    """Run the model on every prompt contained in an uploaded JSON file.

    The file must hold a JSON list of objects. Each object needs a 'prompt'
    field and may carry a 'context' field. The generation settings of the
    request are applied to every item.

    Args:
        request: Upload plus generation settings, resolved through Depends().

    Returns:
        A list with one Prediction per input object, in input order.

    Raises:
        HTTPException: 400 when no file is provided, the file is not valid
            JSON, or its structure is wrong; 500 for any other failure, with
            the exception text as detail.
    """
    try:
        if not request.json_file:
            raise HTTPException(status_code=400, detail="No JSON file provided.")

        content = await request.json_file.read()
        try:
            data = json.loads(content)
        except (json.JSONDecodeError, UnicodeDecodeError) as exc:
            # Undecodable bytes are a client error just like malformed JSON.
            raise HTTPException(status_code=400, detail="Invalid JSON file.") from exc

        # Validate the whole batch before any model work, so a malformed item
        # fails fast and an error inside predict() cannot be mistaken for a
        # missing 'prompt' key.
        items = _parse_batch_items(data)

        # E-mail delivery of the results is currently disabled (see the
        # commented-out send_email helper further down in this file).
        return [
            Prediction(
                content=predict(
                    context=context,
                    prompt=prompt,
                    pipe=pipe,
                    tokenize=request.tokenize,
                    add_generation_prompt=request.add_generation_prompt,
                    max_new_tokens=request.max_new_tokens,
                    do_sample=request.do_sample,
                    temperature=request.temperature,
                    top_k=request.top_k,
                    top_p=request.top_p,
                )
            )
            for context, prompt in items
        ]
    except HTTPException:
        # HTTPException is itself an Exception: without this clause the
        # deliberate 400 responses above would be rewrapped as 500 below.
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    
# # Function to send email
# async def send_email(email: str, content: List[dict]):
#     # Construct the email body by iterating through the list of content objects
#     email_body = "<h1>Your AI Generated Answers</h1>"
#     for item in content:
#         instruction = item.get('instruction', 'Provide an answer to the following question:')
#         input_text = item['input']
#         output_text = item['output']
        
#         email_body += f"""
#         <h2>Instruction:</h2>
#         <p>{instruction}</p>
#         <h2>Input:</h2>
#         <p>{input_text}</p>
#         <h2>Output:</h2>
#         <p>{output_text}</p>
#         <hr>
#         """
    
#     message = MessageSchema(
#         subject="Your AI Generated Answers",
#         recipients=[email],
#         html=email_body,
#         subtype="html"
#     )

#     fm = FastMail(conf)
#     await fm.send_message(message)


# # Ensure your email configuration works
# @app.get("/test-email")
# async def test_email():
#     try:
#         await send_email(os.getenv('TEST_EMAIL'), [{
#                "instruction": "This is a test instruction.",
#                "input": "This is a test input.",
#                "output": "This is a test output.",
#            }])
#            
#         return {"message": "Test email sent successfully"}
#     except Exception as e:
#         raise HTTPException(status_code=500, detail=str(e))
    
# Serve the contents of the "static" directory at the site root, in HTML mode.
app.mount("/", StaticFiles(directory="static", html=True), name="static")


# GET / — return the landing page from its absolute path.
# NOTE(review): this route is registered after the mount on "/" above, and
# routes are presumably matched in registration order, so this handler looks
# unreachable (the static mount would answer "/" first) — confirm, and either
# move it above the mount or drop it.
@app.get("/")
def index() -> FileResponse:
    landing_page = "/app/static/index.html"
    return FileResponse(path=landing_page, media_type="text/html")