Add instructor model

#5
by awinml - opened
Files changed (2) hide show
  1. app.py +25 -9
  2. utils/vector_index.py +5 -2
app.py CHANGED
@@ -290,6 +290,7 @@ elif encoder_model == "Instructor":
290
  pinecone_index_name = "week13-instructor-xl"
291
  pinecone_index = pinecone.Index(pinecone_index_name)
292
  retriever_model = get_instructor_embedding_model()
 
293
 
294
  elif encoder_model == "Hybrid MPNET - SPLADE":
295
  pinecone.init(
@@ -354,9 +355,14 @@ if document_type == "Single-Document":
354
  )
355
 
356
  else:
357
- dense_query_embedding = create_dense_embeddings(
358
- query_text, retriever_model
359
- )
 
 
 
 
 
360
  query_results = query_pinecone(
361
  dense_query_embedding,
362
  num_results,
@@ -410,9 +416,14 @@ else:
410
  context_group.append((results_list, year, quarter, ticker))
411
 
412
  else:
413
- dense_query_embedding = create_dense_embeddings(
414
- query_text, retriever_model
415
- )
 
 
 
 
 
416
  year_quarter_list = year_quarter_range(
417
  start_quarter, start_year, end_quarter, end_year
418
  )
@@ -494,9 +505,14 @@ else:
494
  )
495
 
496
  else:
497
- dense_query_embedding = create_dense_embeddings(
498
- query_text, retriever_model
499
- )
 
 
 
 
 
500
  year_quarter_list = year_quarter_range(
501
  start_quarter, start_year, end_quarter, end_year
502
  )
 
290
  pinecone_index_name = "week13-instructor-xl"
291
  pinecone_index = pinecone.Index(pinecone_index_name)
292
  retriever_model = get_instructor_embedding_model()
293
+ instruction = "Represent the financial question for retrieving supporting documents:"
294
 
295
  elif encoder_model == "Hybrid MPNET - SPLADE":
296
  pinecone.init(
 
355
  )
356
 
357
  else:
358
+ if encoder_model == "Instructor":
359
+ dense_query_embedding = create_dense_embeddings(
360
+ query_text, retriever_model, instruction
361
+ )
362
+ else:
363
+ dense_query_embedding = create_dense_embeddings(
364
+ query_text, retriever_model
365
+ )
366
  query_results = query_pinecone(
367
  dense_query_embedding,
368
  num_results,
 
416
  context_group.append((results_list, year, quarter, ticker))
417
 
418
  else:
419
+ if encoder_model == "Instructor":
420
+ dense_query_embedding = create_dense_embeddings(
421
+ query_text, retriever_model, instruction
422
+ )
423
+ else:
424
+ dense_query_embedding = create_dense_embeddings(
425
+ query_text, retriever_model
426
+ )
427
  year_quarter_list = year_quarter_range(
428
  start_quarter, start_year, end_quarter, end_year
429
  )
 
505
  )
506
 
507
  else:
508
+ if encoder_model == "Instructor":
509
+ dense_query_embedding = create_dense_embeddings(
510
+ query_text, retriever_model, instruction
511
+ )
512
+ else:
513
+ dense_query_embedding = create_dense_embeddings(
514
+ query_text, retriever_model
515
+ )
516
  year_quarter_list = year_quarter_range(
517
  start_quarter, start_year, end_quarter, end_year
518
  )
utils/vector_index.py CHANGED
@@ -1,8 +1,11 @@
1
  import torch
2
 
3
 
4
- def create_dense_embeddings(query, model):
5
- dense_emb = model.encode([query]).tolist()
 
 
 
6
  return dense_emb
7
 
8
 
 
1
  import torch
2
 
3
 
4
def create_dense_embeddings(query, model, instruction=None):
    """Encode a single query into a dense embedding returned as nested lists.

    Args:
        query: The query text to embed.
        model: An embedding model exposing ``encode(batch)``; the value it
            returns must provide ``tolist()``.
        instruction: Optional instruction text. When ``None`` the query is
            encoded on its own; otherwise it is encoded as an
            ``[instruction, query]`` pair.

    Returns:
        The result of ``model.encode(...)`` for a one-item batch, converted
        with ``tolist()``.
    """
    # Identity check: `== None` would invoke a custom `__eq__` if one exists.
    if instruction is None:
        batch = [query]
    else:
        # Instruction-aware path: the model receives the instruction and the
        # query together as one pair inside the batch.
        batch = [[instruction, query]]
    # Both paths go through `encode`; the earlier revision called
    # `model.encoder(...)` on the instruction path, which is not the method
    # used by the plain path.
    dense_emb = model.encode(batch).tolist()
    return dense_emb
10
 
11