mishig HF staff committed on
Commit
6949114
1 Parent(s): a342b03

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -5
app.py CHANGED
@@ -4,15 +4,22 @@ import torch.nn.functional as F
4
  import hnswlib
5
  import gradio as gr
6
  import numpy as np
 
 
7
 
8
  seperator = "-HFSEP-"
9
  base_name="intfloat/e5-large-v2"
10
  device="cuda"
11
  max_length=512
 
12
  tokenizer = AutoTokenizer.from_pretrained(base_name)
13
  model = AutoModel.from_pretrained(base_name).to(device)
14
 
 
 
 
15
  def get_embeddings(input_texts):
 
16
  batch_dict = tokenizer(
17
  input_texts,
18
  max_length=max_length,
@@ -52,16 +59,18 @@ def create_hnsw_index(embeddings_np, space='ip', ef_construction=100, M=16):
52
  def gradio_function(query, paragraph_chunks, top_k):
53
  paragraph_chunks = paragraph_chunks.split(seperator) # Split the comma-separated values into a list
54
  paragraph_chunks = [item.strip() for item in paragraph_chunks] # Trim whitespace from each item
 
55
 
56
- print("creating embeddings")
57
  embeddings_np = get_embeddings([query]+paragraph_chunks)
58
  query_embedding, chunks_embeddings = embeddings_np[0], embeddings_np[1:]
59
 
60
- print("creating index")
61
  search_index = create_hnsw_index(chunks_embeddings)
62
- print("searching index")
63
  labels, _ = search_index.knn_query(query_embedding, k=min(int(top_k), len(chunks_embeddings)))
64
- return f"The closes labels are: {labels}"
 
65
 
66
  interface = gr.Interface(
67
  fn=gradio_function,
@@ -73,4 +82,4 @@ interface = gr.Interface(
73
  outputs="text"
74
  )
75
 
76
- interface.launch()
 
4
  import hnswlib
5
  import gradio as gr
6
  import numpy as np
7
+ import json
8
+ import datetime
9
 
10
  seperator = "-HFSEP-"
11
  base_name="intfloat/e5-large-v2"
12
  device="cuda"
13
  max_length=512
14
+ max_batch_size = 500
15
  tokenizer = AutoTokenizer.from_pretrained(base_name)
16
  model = AutoModel.from_pretrained(base_name).to(device)
17
 
18
+ def current_timestamp():
19
+ return datetime.datetime.utcnow().timestamp()
20
+
21
  def get_embeddings(input_texts):
22
+ input_texts = input_texts[:max_batch_size]
23
  batch_dict = tokenizer(
24
  input_texts,
25
  max_length=max_length,
 
59
  def gradio_function(query, paragraph_chunks, top_k):
60
  paragraph_chunks = paragraph_chunks.split(seperator) # Split the comma-separated values into a list
61
  paragraph_chunks = [item.strip() for item in paragraph_chunks] # Trim whitespace from each item
62
+ print("Len of batches", len(paragraph_chunks))
63
 
64
+ print("creating embeddings", current_timestamp())
65
  embeddings_np = get_embeddings([query]+paragraph_chunks)
66
  query_embedding, chunks_embeddings = embeddings_np[0], embeddings_np[1:]
67
 
68
+ print("creating index", current_timestamp())
69
  search_index = create_hnsw_index(chunks_embeddings)
70
+ print("searching index", current_timestamp())
71
  labels, _ = search_index.knn_query(query_embedding, k=min(int(top_k), len(chunks_embeddings)))
72
+ labels = labels[0].tolist()
73
+ return json.dumps(labels)
74
 
75
  interface = gr.Interface(
76
  fn=gradio_function,
 
82
  outputs="text"
83
  )
84
 
85
+ interface.launch()