Achyut Tiwari commited on
Commit
8c7ad44
1 Parent(s): e0ccdc1

Add files via upload

Browse files
lfqa_server/Dockerfile ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:11.2.2-runtime-ubuntu20.04
2
+ #set up environment
3
+ RUN apt-get update && apt-get install --no-install-recommends --no-install-suggests -y curl
4
+ RUN apt-get install unzip
5
+ RUN apt-get -y install python3
6
+ RUN apt-get -y install python3-pip
7
+
8
+ WORKDIR /code
9
+
10
+ ENV HF_HOME=/code/cache
11
+
12
+ COPY ./requirements.txt /code/requirements.txt
13
+
14
+ RUN pip3 install torch==1.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
15
+ RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
16
+
17
+ COPY ./main.py /code/app/main.py
18
+
19
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8080"]
lfqa_server/__init__.py ADDED
File without changes
lfqa_server/main.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from fastapi import FastAPI, Depends, status
3
+ from fastapi.responses import PlainTextResponse
4
+ from pydantic import BaseModel
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
+
7
+ import time
8
+ from typing import Dict, List, Optional
9
+
10
+ import jwt
11
+ from decouple import config
12
+ from fastapi import Request, HTTPException
13
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
14
+
15
+ JWT_SECRET = config("secret")
16
+ JWT_ALGORITHM = config("algorithm")
17
+
18
+ app = FastAPI()
19
+ app.ready = False
20
+
21
+ device = ("cuda" if torch.cuda.is_available() else "cpu")
22
+ tokenizer = AutoTokenizer.from_pretrained('vblagoje/bart_lfqa')
23
+ model = AutoModelForSeq2SeqLM.from_pretrained('vblagoje/bart_lfqa').to(device)
24
+ _ = model.eval()
25
+
26
+
27
+ class JWTBearer(HTTPBearer):
28
+ def __init__(self, auto_error: bool = True):
29
+ super(JWTBearer, self).__init__(auto_error=auto_error)
30
+
31
+ async def __call__(self, request: Request):
32
+ credentials: HTTPAuthorizationCredentials = await super(JWTBearer, self).__call__(request)
33
+ if credentials:
34
+ if not credentials.scheme == "Bearer":
35
+ raise HTTPException(status_code=403, detail="Invalid authentication scheme.")
36
+ if not self.verify_jwt(credentials.credentials):
37
+ raise HTTPException(status_code=403, detail="Invalid token or expired token.")
38
+ return credentials.credentials
39
+ else:
40
+ raise HTTPException(status_code=403, detail="Invalid authorization code.")
41
+
42
+ def verify_jwt(self, jwtoken: str) -> bool:
43
+ isTokenValid: bool = False
44
+
45
+ try:
46
+ payload = decodeJWT(jwtoken)
47
+ except:
48
+ payload = None
49
+ if payload:
50
+ isTokenValid = True
51
+ return isTokenValid
52
+
53
+
54
+ def token_response(token: str):
55
+ return {
56
+ "access_token": token
57
+ }
58
+
59
+
60
+ def signJWT(user_id: str) -> Dict[str, str]:
61
+ payload = {
62
+ "user_id": user_id,
63
+ "expires": time.time() + 6000
64
+ }
65
+ token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
66
+
67
+ return token_response(token)
68
+
69
+
70
+ def decodeJWT(token: str) -> dict:
71
+ try:
72
+ decoded_token = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
73
+ return decoded_token if decoded_token["expires"] >= time.time() else None
74
+ except:
75
+ return {}
76
+
77
+
78
+ class LFQAParameters(BaseModel):
79
+ min_length: int = 50
80
+ max_length: int = 250
81
+ do_sample: bool = False
82
+ early_stopping: bool = True
83
+ num_beams: int = 8
84
+ temperature: float = 1.0
85
+ top_k: float = None
86
+ top_p: float = None
87
+ no_repeat_ngram_size: int = 3
88
+ num_return_sequences: int = 1
89
+
90
+
91
+ class InferencePayload(BaseModel):
92
+ model_input: str
93
+ parameters: Optional[LFQAParameters] = LFQAParameters()
94
+
95
+
96
+ @app.on_event("startup")
97
+ def startup():
98
+ app.ready = True
99
+
100
+
101
+ @app.get("/healthz")
102
+ def healthz():
103
+ if app.ready:
104
+ return PlainTextResponse("ok")
105
+ return PlainTextResponse("service unavailable", status_code=status.HTTP_503_SERVICE_UNAVAILABLE)
106
+
107
+
108
+ @app.post("/generate/", dependencies=[Depends(JWTBearer())])
109
+ def generate(context: InferencePayload):
110
+
111
+ model_input = tokenizer(context.model_input, truncation=True, padding=True, return_tensors="pt")
112
+ param = context.parameters
113
+ generated_answers_encoded = model.generate(input_ids=model_input["input_ids"].to(device),
114
+ attention_mask=model_input["attention_mask"].to(device),
115
+ min_length=param.min_length,
116
+ max_length=param.max_length,
117
+ do_sample=param.do_sample,
118
+ early_stopping=param.early_stopping,
119
+ num_beams=param.num_beams,
120
+ temperature=param.temperature,
121
+ top_k=param.top_k,
122
+ top_p=param.top_p,
123
+ no_repeat_ngram_size=param.no_repeat_ngram_size,
124
+ num_return_sequences=param.num_return_sequences)
125
+ answers = tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,
126
+ clean_up_tokenization_spaces=True)
127
+ results = []
128
+ for answer in answers:
129
+ results.append({"generated_text": answer})
130
+ return results
lfqa_server/requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
1
+ datasets
2
+ transformers
3
+ fastapi
4
+ faiss-gpu
5
+ uvicorn[standard]
6
+ PyJWT==1.7.1
7
+ python-decouple==3.3