tomiwa1a committed
Commit 766b395
Parent: 93a37ae

add generate_answer for long-form question answering


see: https://github.com/atilatech/atila-core-service/pull/7


![haystack-lfqa-1.png](https://s3.amazonaws.com/moonup/production/uploads/1674311706276-63a4969d658851481f7729dd.png)

Files changed (1): handler.py (+55 -4)
handler.py CHANGED

```diff
@@ -3,6 +3,7 @@ from typing import Dict
 from sentence_transformers import SentenceTransformer
 from tqdm import tqdm
 import whisper
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import torch
 import pytube
 import time
```
```diff
@@ -12,11 +13,13 @@ class EndpointHandler():
     # load the model
     WHISPER_MODEL_NAME = "tiny.en"
     SENTENCE_TRANSFORMER_MODEL_NAME = "multi-qa-mpnet-base-dot-v1"
+    QUESTION_ANSWER_MODEL_NAME = "vblagoje/bart_lfqa"
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
     def __init__(self, path=""):
 
         device = "cuda" if torch.cuda.is_available() else "cpu"
-        print(f'whisper will use: {device}')
+        print(f'whisper and question_answer_model will use: {device}')
 
         t0 = time.time()
         self.whisper_model = whisper.load_model(self.WHISPER_MODEL_NAME).to(device)
```
```diff
@@ -31,6 +34,13 @@ class EndpointHandler():
 
         total = t1 - t0
         print(f'Finished loading sentence_transformer_model in {total} seconds')
+
+        self.question_answer_tokenizer = AutoTokenizer.from_pretrained(self.QUESTION_ANSWER_MODEL_NAME)
+        t0 = time.time()
+        self.question_answer_model = AutoModelForSeq2SeqLM.from_pretrained(self.QUESTION_ANSWER_MODEL_NAME).to(device)
+        t1 = time.time()
+        total = t1 - t0
+        print(f'Finished loading question_answer_model in {total} seconds')
 
     def __call__(self, data: Dict[str, str]) -> Dict:
         """
```
```diff
@@ -48,6 +58,7 @@ class EndpointHandler():
                 f" See: https://huggingface.co/docs/inference-endpoints/guides/custom_handler#2-create-endpointhandler-cp")
         video_url = data.pop("video_url", None)
         query = data.pop("query", None)
+        long_form_answer = data.pop("long_form_answer", None)
         encoded_segments = {}
         if video_url:
             video_with_transcript = self.transcribe_video(video_url)
```
```diff
@@ -63,11 +74,27 @@ class EndpointHandler():
                 **encoded_segments
             }
         elif query:
-            query = [{"text": query, "id": ""}] if isinstance(query, str) else query
-            encoded_segments = self.encode_sentences(query)
+            if long_form_answer:
+                context = data.pop("context", None)
+                answer = self.generate_answer(query, context)
+                response = {
+                    "answer": answer
+                }
+
+                return response
+            else:
+                query = [{"text": query, "id": ""}] if isinstance(query, str) else query
+                encoded_segments = self.encode_sentences(query)
 
+                response = {
+                    "encoded_segments": encoded_segments
+                }
+
+                return response
+
+        else:
             return {
-            "encoded_segments": encoded_segments
+                "error": "'video_url' or 'query' must be provided"
             }
 
     def transcribe_video(self, video_url):
```
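With this hunk, `__call__` routes three request shapes: `video_url` triggers transcription plus encoding, `query` with `long_form_answer` and `context` returns a generated answer, and a bare `query` returns encoded segments. Below is a minimal client-side sketch of the new long-form path, assuming a deployed Inference Endpoint; the URL and token are placeholders, and the exact request envelope (for example, whether keys are nested under `inputs`) depends on validation code outside this hunk:

```python
import requests

# Placeholder values: substitute your own endpoint URL and access token.
ENDPOINT_URL = "https://your-endpoint.endpoints.huggingface.cloud"
HF_TOKEN = "hf_..."

# Keys mirror what EndpointHandler.__call__ pops off the request data.
payload = {
    "query": "What is long-form question answering?",
    "long_form_answer": True,
    "context": [
        "Long-form question answering generates paragraph-length answers.",
        "It conditions a seq2seq model on the question plus support documents.",
    ],
}

response = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
    json=payload,
)
print(response.json())  # expected shape: {"answer": ["..."]}
```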
```diff
@@ -140,6 +167,30 @@ class EndpointHandler():
 
         return all_batches
 
+    def generate_answer(self, query, documents):
+
+        # concatenate question and support documents into BART input
+        conditioned_doc = "<P> " + " <P> ".join([d for d in documents])
+        query_and_docs = "question: {} context: {}".format(query, conditioned_doc)
+
+        model_input = self.question_answer_tokenizer(query_and_docs, truncation=False, padding=True, return_tensors="pt")
+
+        generated_answers_encoded = self.question_answer_model.generate(input_ids=model_input["input_ids"].to(self.device),
+                                                                        attention_mask=model_input["attention_mask"].to(self.device),
+                                                                        min_length=64,
+                                                                        max_length=256,
+                                                                        do_sample=False,
+                                                                        early_stopping=True,
+                                                                        num_beams=8,
+                                                                        temperature=1.0,
+                                                                        top_k=None,
+                                                                        top_p=None,
+                                                                        eos_token_id=self.question_answer_tokenizer.eos_token_id,
+                                                                        no_repeat_ngram_size=3,
+                                                                        num_return_sequences=1)
+        answer = self.question_answer_tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        return answer
+
     @staticmethod
     def combine_transcripts(video, window=6, stride=3):
         """
```
 