Achyut Tiwari commited on
Commit
da74da1
0 Parent(s):

Add files via upload

Browse files
context_server/Dockerfile ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:11.2.2-runtime-ubuntu20.04
2
+ #set up environment
3
+ RUN apt-get update && apt-get install --no-install-recommends --no-install-suggests -y curl
4
+ RUN apt-get install unzip
5
+ RUN apt-get -y install python3
6
+ RUN apt-get -y install python3-pip
7
+
8
+ WORKDIR /code
9
+
10
+ ENV HF_HOME=/code/cache
11
+
12
+ COPY ./requirements.txt /code/requirements.txt
13
+
14
+ RUN pip3 install --pre torch -f https://download.pytorch.org/whl/nightly/cu113/torch_nightly.html
15
+ RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
16
+
17
+ COPY ./main.py /code/app/main.py
18
+
19
+ COPY ./data/kilt_wiki_prepared/ /code/data/kilt_wiki_prepared
20
+
21
+ COPY ./data/kilt_wikipedia.faiss /code/data/kilt_wikipedia.faiss
22
+
23
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8080"]
context_server/__init__.py ADDED
File without changes
context_server/main.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from fastapi import FastAPI, Depends, status
3
+ from fastapi.responses import PlainTextResponse
4
+ from transformers import AutoTokenizer, AutoModel, DPRQuestionEncoder
5
+
6
+ from datasets import load_from_disk
7
+ import time
8
+ from typing import Dict
9
+
10
+ import jwt
11
+ from decouple import config
12
+ from fastapi import Request, HTTPException
13
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
14
+
15
+ JWT_SECRET = config("secret")
16
+ JWT_ALGORITHM = config("algorithm")
17
+
18
+ app = FastAPI()
19
+ app.ready = False
20
+ columns = ['kilt_id', 'wikipedia_id', 'wikipedia_title', 'text', 'anchors', 'categories',
21
+ 'wikidata_info', 'history']
22
+
23
+ min_snippet_length = 20
24
+ topk = 21
25
+ device = ("cuda" if torch.cuda.is_available() else "cpu")
26
+ model = DPRQuestionEncoder.from_pretrained("vblagoje/dpr-question_encoder-single-lfqa-wiki").to(device)
27
+ tokenizer = AutoTokenizer.from_pretrained("vblagoje/dpr-question_encoder-single-lfqa-wiki")
28
+ _ = model.eval()
29
+
30
+ index_file_name = "./data/kilt_wikipedia.faiss"
31
+
32
+ kilt_wikipedia_paragraphs = load_from_disk("./data/kilt_wiki_prepared")
33
+ # use paragraphs that are not simple fragments or very short sentences
34
+ kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(lambda x: x["end_character"] > 200)
35
+
36
+
37
+ class JWTBearer(HTTPBearer):
38
+ def __init__(self, auto_error: bool = True):
39
+ super(JWTBearer, self).__init__(auto_error=auto_error)
40
+
41
+ async def __call__(self, request: Request):
42
+ credentials: HTTPAuthorizationCredentials = await super(JWTBearer, self).__call__(request)
43
+ if credentials:
44
+ if not credentials.scheme == "Bearer":
45
+ raise HTTPException(status_code=403, detail="Invalid authentication scheme.")
46
+ if not self.verify_jwt(credentials.credentials):
47
+ raise HTTPException(status_code=403, detail="Invalid token or expired token.")
48
+ return credentials.credentials
49
+ else:
50
+ raise HTTPException(status_code=403, detail="Invalid authorization code.")
51
+
52
+ def verify_jwt(self, jwtoken: str) -> bool:
53
+ isTokenValid: bool = False
54
+
55
+ try:
56
+ payload = decodeJWT(jwtoken)
57
+ except:
58
+ payload = None
59
+ if payload:
60
+ isTokenValid = True
61
+ return isTokenValid
62
+
63
+
64
+ def token_response(token: str):
65
+ return {
66
+ "access_token": token
67
+ }
68
+
69
+
70
+ def signJWT(user_id: str) -> Dict[str, str]:
71
+ payload = {
72
+ "user_id": user_id,
73
+ "expires": time.time() + 6000
74
+ }
75
+ token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
76
+
77
+ return token_response(token)
78
+
79
+
80
+ def decodeJWT(token: str) -> dict:
81
+ try:
82
+ decoded_token = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
83
+ return decoded_token if decoded_token["expires"] >= time.time() else None
84
+ except:
85
+ return {}
86
+
87
+
88
+ def embed_questions_for_retrieval(questions):
89
+ query = tokenizer(questions, max_length=128, padding=True, truncation=True, return_tensors="pt")
90
+ with torch.no_grad():
91
+ q_reps = model(query["input_ids"].to(device), query["attention_mask"].to(device)).pooler_output
92
+ return q_reps.cpu().numpy()
93
+
94
+ def query_index(question):
95
+ question_embedding = embed_questions_for_retrieval([question])
96
+ scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples("embeddings", question_embedding, k=topk)
97
+ columns = ['wikipedia_id', 'title', 'text', 'section', 'start_paragraph_id', 'end_paragraph_id',
98
+ 'start_character', 'end_character']
99
+ retrieved_examples = []
100
+ r = list(zip(wiki_passages[k] for k in columns))
101
+ for i in range(topk):
102
+ retrieved_examples.append({k: v for k, v in zip(columns, [r[j][0][i] for j in range(len(columns))])})
103
+ return retrieved_examples
104
+
105
+
106
+ @app.on_event("startup")
107
+ def startup():
108
+ kilt_wikipedia_paragraphs.load_faiss_index("embeddings", index_file_name, device=0)
109
+ app.ready = True
110
+
111
+
112
+ @app.get("/healthz")
113
+ def healthz():
114
+ if app.ready:
115
+ return PlainTextResponse("ok")
116
+ return PlainTextResponse("service unavailable", status_code=status.HTTP_503_SERVICE_UNAVAILABLE)
117
+
118
+
119
+ @app.get("/find_context", dependencies=[Depends(JWTBearer())])
120
+ def find_context(question: str = None):
121
+ return [res for res in query_index(question) if len(res["text"].split()) > min_snippet_length][:int(topk / 3)]
122
+
context_server/requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ datasets
2
+ transformers
3
+ fastapi
4
+ faiss-gpu
5
+ uvicorn[standard]
6
+ PyJWT==1.7.1
7
+ python-decouple==3.3