SebastianSchramm committed
Commit fca97ef
1 Parent(s): 053ffc5

use models

Files changed (2):
  1. data/paris-2024-faq.json +0 -0
  2. server.py +98 -3

data/paris-2024-faq.json ADDED
The diff for this file is too large to render. See raw diff.
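The FAQ data itself is too large to render, but from how get_data() in server.py consumes it below, each record presumably carries a language tag, a question label, and an answer body. A sketch of the inferred shape (field values invented for illustration):

# Hypothetical record, inferred from the fields get_data() reads below.
[
    {
        "lang": "en",
        "label": "When do the Paris 2024 Olympic Games take place?",
        "body": "The Olympic Games Paris 2024 will take place from 26 July to 11 August 2024."
    }
]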
server.py CHANGED
@@ -1,14 +1,32 @@
  import logging
+ import json
+ from contextlib import asynccontextmanager
+ from typing import Any, List, Tuple
+ import random

  from fastapi import FastAPI
  from pydantic import BaseModel
+ from FlagEmbedding import BGEM3FlagModel, FlagReranker
+ from starlette.requests import Request
+ import torch


+ random.seed(42)
+
  logging.basicConfig()
  logger = logging.getLogger(__name__)
  logger.setLevel(logging.INFO)


+ def get_data(model):
+     with open("data/paris-2024-faq.json") as f:
+         data = json.load(f)
+     data = [it for it in data if it['lang'] == 'en']
+     questions = [it['label'] for it in data]
+     q_embeddings = model.encode(questions, return_dense=False, return_sparse=False, return_colbert_vecs=True)
+     return q_embeddings['colbert_vecs'], questions, [it['body'] for it in data]
+
+
  class InputLoad(BaseModel):
      question: str

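With return_colbert_vecs=True, BGEM3FlagModel.encode() returns a multi-vector (per-token) representation per input rather than a single dense embedding; get_data() stores one such array per FAQ question. A minimal sketch of what comes back (the 1024 hidden size is BGE-M3's usual dimension, stated from memory rather than from this commit):

from FlagEmbedding import BGEM3FlagModel

# Same checkpoint as load_models() uses further down.
model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
out = model.encode(
    ["When do the Games start?"],
    return_dense=False,
    return_sparse=False,
    return_colbert_vecs=True,
)
vecs = out["colbert_vecs"][0]  # one numpy array per input, roughly (num_tokens, 1024)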
@@ -17,7 +35,31 @@ class ResponseLoad(BaseModel):
      answer: str


- app = FastAPI()
+ class ML(BaseModel):
+     retriever: Any
+     ranker: Any
+     data: Tuple[List[Any], List[str], List[str]]
+
+
+ def load_models(app: FastAPI) -> FastAPI:
+     retriever = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
+     ranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True)
+     ml = ML(
+         retriever=retriever,
+         ranker=ranker,
+         data=get_data(retriever),
+     )
+     app.ml = ml
+     return app
+
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     app = load_models(app=app)
+     yield
+
+
+ app = FastAPI(lifespan=lifespan)


  @app.get("/health")
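Stashing the models directly as app.ml works because Starlette applications accept arbitrary attributes, but the documented home for shared objects is app.state. A minimal self-contained sketch of that variant (the /demo route and placeholder value are hypothetical):

from contextlib import asynccontextmanager
from fastapi import FastAPI, Request

@asynccontextmanager
async def lifespan(app: FastAPI):
    # In server.py this would be the ML container instead of a string.
    app.state.ml = "models would be loaded here"
    yield

app = FastAPI(lifespan=lifespan)

@app.get("/demo")
async def demo(request: Request) -> dict:
    return {"ml": request.app.state.ml}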
 
@@ -26,5 +68,58 @@ def health_check():


  @app.post("/answer/")
- async def receive(input_load: InputLoad) -> ResponseLoad:
-     return ResponseLoad(answer="Hi, happy to help you with that. According to my information this is possible! Hope that was helpful!")
+ async def receive(input_load: InputLoad, request: Request) -> ResponseLoad:
+     ml: ML = request.app.ml
+     candidate_indices, candidate_scores = get_candidates(input_load.question, ml)
+     answer_candidate, rank_score, retriever_score = rerank_candidates(input_load.question, candidate_indices, candidate_scores, ml)
+     answer = get_final_answer(answer_candidate, retriever_score)
+     return ResponseLoad(answer=answer)
+
+
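For reference, a hypothetical smoke test against the new endpoint; it assumes the app is being served locally on port 8000, and the question text is arbitrary:

import requests

resp = requests.post(
    "http://localhost:8000/answer/",
    json={"question": "When do the Paris 2024 Olympics start?"},
)
print(resp.json()["answer"])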
+ def get_candidates(question, ml, topk=5):
+     question_emb = ml.retriever.encode([question], return_dense=False, return_sparse=False, return_colbert_vecs=True)
+     question_emb = question_emb['colbert_vecs'][0]
+     scores = [ml.retriever.colbert_score(question_emb, faq_emb) for faq_emb in ml.data[0]]
+     scores_tensor = torch.stack(scores)
+     top_values, top_indices = torch.topk(scores_tensor, topk)
+     return top_indices.tolist(), top_values.tolist()
+
+
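colbert_score() computes ColBERT-style late interaction: for each query token vector, take the maximum similarity over the passage's token vectors, then average over the query tokens. A rough standalone sketch of that scoring rule (not FlagEmbedding's exact implementation; it assumes row-normalized vectors, hence the random inputs are normalized below):

import torch
import torch.nn.functional as F

def maxsim(q_vecs: torch.Tensor, p_vecs: torch.Tensor) -> torch.Tensor:
    sim = q_vecs @ p_vecs.T              # (q_tokens, p_tokens) pairwise similarities
    return sim.max(dim=1).values.mean()  # best passage token per query token, averaged

score = maxsim(F.normalize(torch.randn(5, 8), dim=-1),
               F.normalize(torch.randn(7, 8), dim=-1))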
+ def rerank_candidates(question, indices, values, ml):
+     candidate_answers = [ml.data[2][_ind] for _ind in indices]
+     scores = ml.ranker.compute_score([[question, it] for it in candidate_answers])
+     rank_score = max(scores)
+     rank_ind = scores.index(rank_score)
+     retriever_score = values[rank_ind]
+     return candidate_answers[rank_ind], rank_score, retriever_score
+
+
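FlagReranker.compute_score() returns unbounded relevance logits, which is presumably why the thresholds in get_final_answer() below are applied to the bounded ColBERT retriever score, while rank_score is computed but never used for routing. If calibrated 0-to-1 reranker scores were wanted, a sigmoid is the usual conversion (FlagEmbedding also has a normalize flag on compute_score for this, if memory serves):

import math

def logit_to_probability(logit: float) -> float:
    # Standard sigmoid over a reranker logit.
    return 1.0 / (1.0 + math.exp(-logit))

print(logit_to_probability(2.3))  # ~0.91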
+ def get_final_answer(answer, retriever_score):
+     logger.info(f"Retriever score: {retriever_score}")
+     if retriever_score < 0.65:
+         # nothing relevant found!
+         return random.sample(NOT_FOUND_ANSWERS, k=1)[0]
+     elif retriever_score < 0.8:
+         # might be relevant, but let's be careful
+         return f"{random.sample(ROUGH_MATCH_INTROS, k=1)[0]}\n{answer}"
+     else:
+         # good match
+         return f"{random.sample(GOOD_MATCH_INTROS, k=1)[0]}\n{answer}\n{random.sample(GOOD_MATCH_ENDS, k=1)[0]}"
+
+
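The reply routing keys only on the retriever score. With server.py imported, the three bands could be exercised like this (scores are invented to hit each branch; the exact wording is drawn from the pools defined below):

for score in (0.50, 0.72, 0.91):
    print(score, "->", get_final_answer("A sample FAQ answer.", score))
# 0.50 -> a canned NOT_FOUND_ANSWERS reply
# 0.72 -> a ROUGH_MATCH_INTROS line, then the answer
# 0.91 -> a GOOD_MATCH_INTROS line, the answer, then a GOOD_MATCH_ENDS line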
+ NOT_FOUND_ANSWERS = [
+     "I'm sorry, but I couldn't find any information related to your question in my knowledge base.",
+     "Apologies, but I don't have the information you're looking for at the moment.",
+     "I’m sorry, I couldn’t locate any relevant details in my current data.",
+     "Unfortunately, I wasn't able to find an answer to your query. Can I help with something else?",
+     "I'm afraid I don't have the information you need right now. Please feel free to ask another question.",
+     "Sorry, I couldn't find anything that matches your question in my knowledge base.",
+     "I apologize, but I wasn't able to retrieve information related to your query.",
+     "I'm sorry, but it looks like I don't have an answer for that. Is there anything else I can assist with?",
+     "Regrettably, I couldn't find the information you requested. Can I help you with anything else?",
+     "I’m sorry, but I don't have the details you're seeking in my knowledge database."
+ ]
+
+ GOOD_MATCH_INTROS = ["Super!"]
+ GOOD_MATCH_ENDS = ["Hope this helps!"]
+ ROUGH_MATCH_INTROS = ["Not sure if that answers your question!"]
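Finally, a hypothetical entry point for running this locally (the Space presumably launches uvicorn itself; host and port are just the usual defaults for a quick test, and the first start will download both BAAI checkpoints):

import uvicorn

if __name__ == "__main__":
    uvicorn.run("server:app", host="0.0.0.0", port=8000)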