nickmuchi commited on
Commit
e1f6c5c
1 Parent(s): e232116

Update functions.py

Browse files
Files changed (1) hide show
  1. functions.py +4 -95
functions.py CHANGED
@@ -119,15 +119,6 @@ def load_asr_model(asr_model_name):
119
  asr_model = whisper.load_model(asr_model_name)
120
 
121
  return asr_model
122
-
123
- # @st.experimental_singleton(suppress_st_warning=True)
124
- # def load_sbert(model_name):
125
- # if 'hkunlp' in model_name:
126
- # sbert = INSTRUCTOR(model_name)
127
- # else:
128
- # sbert = SentenceTransformer(model_name)
129
-
130
- # return sbert
131
 
132
  @st.experimental_singleton(suppress_st_warning=True)
133
  def process_corpus(corpus, tok, title, embeddings, chunk_size=200, overlap=50):
@@ -185,7 +176,7 @@ def embed_text(query,corpus,title,embedding_model,emb_tok,chain_type='stuff'):
185
 
186
  docs = [d[0] for d in docs]
187
 
188
- if chain_type == 'stuff':
189
 
190
  PROMPT = PromptTemplate(template=template,
191
  input_variables=["summaries", "question"],
@@ -200,7 +191,7 @@ def embed_text(query,corpus,title,embedding_model,emb_tok,chain_type='stuff'):
200
 
201
  return answer['output_text']
202
 
203
- elif chain_type == 'refine':
204
 
205
  initial_qa_prompt = PromptTemplate(
206
  input_variables=["context_str", "question"], template=initial_qa_template
@@ -211,62 +202,6 @@ def embed_text(query,corpus,title,embedding_model,emb_tok,chain_type='stuff'):
211
 
212
  return answer['output_text']
213
 
214
- # @st.experimental_memo(suppress_st_warning=True)
215
- # def embed_text(query,corpus,embedding_model):
216
-
217
- # '''Embed text and generate semantic search scores'''
218
-
219
- # #If model is e5 then apply prefixes to query and passage
220
- # if embedding_model == 'intfloat/e5-base':
221
- # search_input = 'query: '+ query
222
- # passages_emb = ['passage: ' + sentence for sentence in corpus]
223
-
224
- # elif embedding_model == 'hkunlp/instructor-base':
225
- # search_input = [['Represent the Financial question for retrieving supporting paragraphs: ', query]]
226
- # passages_emb = [['Represent the Financial paragraph for retrieval: ',sentence] for sentence in corpus]
227
-
228
- # else:
229
- # search_input = query
230
- # passages_emb = corpus
231
-
232
-
233
- # #Embed corpus and question
234
- # corpus_embedding = sbert.encode(passages_emb, convert_to_tensor=True)
235
- # question_embedding = sbert.encode(search_input, convert_to_tensor=True)
236
- # question_embedding = question_embedding.cpu()
237
- # corpus_embedding = corpus_embedding.cpu()
238
-
239
- # # #Calculate similarity scores and rank
240
- # hits = util.semantic_search(question_embedding, corpus_embedding, top_k=2)
241
- # hits = hits[0] # Get the hits for the first query
242
-
243
- # # ##### Re-Ranking #####
244
- # # Now, score all retrieved passages with the cross_encoder
245
- # cross_inp = [[search_input, corpus[hit['corpus_id']]] for hit in hits]
246
-
247
- # if embedding_model == 'hkunlp/instructor-base':
248
- # result = []
249
-
250
- # for sublist in cross_inp:
251
- # question = sublist[0][0][1]
252
- # document = sublist[1][1]
253
- # result.append([question, document])
254
-
255
- # cross_inp = result
256
-
257
- # cross_scores = cross_encoder.predict(cross_inp)
258
-
259
- # # Sort results by the cross-encoder scores
260
- # for idx in range(len(cross_scores)):
261
- # hits[idx]['cross-score'] = cross_scores[idx]
262
-
263
- # # Output of top-3 hits from re-ranker
264
- # # st.markdown("\n-------------------------\n")
265
- # # st.subheader(f"Top-{top_k} Cross-Encoder Re-ranker hits")
266
- # hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
267
-
268
- # return hits
269
-
270
  @st.experimental_singleton(suppress_st_warning=True)
271
  def get_spacy():
272
  nlp = en_core_web_lg.load()
@@ -350,32 +285,7 @@ def chunk_long_text(text,threshold,window_size=3,stride=2):
350
  end_idx = min(start_idx+window_size, len(paragraph))
351
  passages.append(" ".join(paragraph[start_idx:end_idx]))
352
 
353
- return passages
354
-
355
- @st.experimental_memo(suppress_st_warning=True)
356
- def chunk_and_preprocess_text(text,thresh=500):
357
-
358
- """Chunk text longer than n tokens for summarization"""
359
-
360
- sentences = sent_tokenize(text)
361
-
362
- current_chunk = 0
363
- chunks = []
364
-
365
- for sentence in sentences:
366
- if len(chunks) == current_chunk + 1:
367
- if len(chunks[current_chunk]) + len(sentence.split(" ")) <= thresh:
368
- chunks[current_chunk].extend(sentence.split(" "))
369
- else:
370
- current_chunk += 1
371
- chunks.append(sentence.split(" "))
372
- else:
373
- chunks.append(sentence.split(" "))
374
-
375
- for chunk_id in range(len(chunks)):
376
- chunks[chunk_id] = " ".join(chunks[chunk_id])
377
-
378
- return chunks
379
 
380
 
381
  def summary_downloader(raw_text):
@@ -830,5 +740,4 @@ def save_network_html(kb, filename="network.html"):
830
 
831
 
832
  nlp = get_spacy()
833
- sent_pipe, sum_pipe, ner_pipe, cross_encoder, kg_model, kg_tokenizer, emb_tokenizer = load_models()
834
- sbert = load_sbert('all-MiniLM-L12-v2')
 
119
  asr_model = whisper.load_model(asr_model_name)
120
 
121
  return asr_model
 
 
 
 
 
 
 
 
 
122
 
123
  @st.experimental_singleton(suppress_st_warning=True)
124
  def process_corpus(corpus, tok, title, embeddings, chunk_size=200, overlap=50):
 
176
 
177
  docs = [d[0] for d in docs]
178
 
179
+ if chain_type == 'Normal':
180
 
181
  PROMPT = PromptTemplate(template=template,
182
  input_variables=["summaries", "question"],
 
191
 
192
  return answer['output_text']
193
 
194
+ elif chain_type == 'Refined':
195
 
196
  initial_qa_prompt = PromptTemplate(
197
  input_variables=["context_str", "question"], template=initial_qa_template
 
202
 
203
  return answer['output_text']
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  @st.experimental_singleton(suppress_st_warning=True)
206
  def get_spacy():
207
  nlp = en_core_web_lg.load()
 
285
  end_idx = min(start_idx+window_size, len(paragraph))
286
  passages.append(" ".join(paragraph[start_idx:end_idx]))
287
 
288
+ return passages
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
 
290
 
291
  def summary_downloader(raw_text):
 
740
 
741
 
742
  nlp = get_spacy()
743
+ sent_pipe, sum_pipe, ner_pipe, cross_encoder, kg_model, kg_tokenizer, emb_tokenizer = load_models()