derek-thomas HF staff commited on
Commit
9a4e478
1 Parent(s): 8de88bd

Minor fixes

Browse files
Files changed (2) hide show
  1. backend/query_llm.py +6 -7
  2. backend/semantic_search.py +20 -7
backend/query_llm.py CHANGED
@@ -1,17 +1,16 @@
1
  import os
2
-
3
- import requests
4
  from os import getenv
5
 
 
 
6
 
7
  API_URL = getenv('API_URL')
8
  BEARER = getenv('BEARER')
9
 
10
-
11
  headers = {
12
- "Authorization": f"Bearer {BEARER}",
13
- "Content-Type": "application/json"
14
- }
15
 
16
 
17
  def call_jais(payload):
@@ -26,7 +25,7 @@ def call_jais(payload):
26
 
27
 
28
  def generate(prompt: str):
29
- payload = {'inputs': '', 'prompt':prompt}
30
  response = call_jais(payload)
31
  return response
32
 
 
1
  import os
 
 
2
  from os import getenv
3
 
4
+ import gradio as gr
5
+ import requests
6
 
7
  API_URL = getenv('API_URL')
8
  BEARER = getenv('BEARER')
9
 
 
10
  headers = {
11
+ "Authorization": f"Bearer {BEARER}",
12
+ "Content-Type": "application/json"
13
+ }
14
 
15
 
16
  def call_jais(payload):
 
25
 
26
 
27
  def generate(prompt: str):
28
+ payload = {'inputs': '', 'prompt': prompt}
29
  response = call_jais(payload)
30
  return response
31
 
backend/semantic_search.py CHANGED
@@ -1,9 +1,10 @@
1
  import logging
2
- from pathlib import Path
3
  import time
 
4
 
5
  import lancedb
6
  from sentence_transformers import SentenceTransformer
 
7
  import spaces
8
 
9
 
@@ -17,7 +18,7 @@ start_time = time.perf_counter()
17
  proj_dir = Path(__file__).parents[1]
18
 
19
  # Log the time taken to load the QdrantDocumentStore
20
- db = lancedb.connect(proj_dir/"lancedb")
21
  tbl = db.open_table('arabic-wiki')
22
  lancedb_loading_time = time.perf_counter() - start_time
23
  logger.info(f"Time taken to load LanceDB: {lancedb_loading_time:.6f} seconds")
@@ -25,23 +26,35 @@ logger.info(f"Time taken to load LanceDB: {lancedb_loading_time:.6f} seconds")
25
  # Start the timer for loading the EmbeddingRetriever
26
  start_time = time.perf_counter()
27
 
28
- name="sentence-transformers/paraphrase-multilingual-minilm-l12-v2"
29
- st_model = SentenceTransformer(name, device='cuda')
 
 
30
 
31
  # used for both training and querying
 
 
 
 
 
 
 
 
32
  @spaces.GPU
33
  def embed_func(query):
34
- return st_model.encode(query)
 
35
 
36
  def vector_search(query_vector, top_k):
37
  return tbl.search(query_vector).limit(top_k).to_list()
38
 
 
39
  def retriever(query, top_k=3):
40
- query_vector = embed_func(query)
41
  documents = vector_search(query_vector, top_k)
42
  return documents
43
 
44
 
45
  # Log the time taken to load the EmbeddingRetriever
46
  retriever_loading_time = time.perf_counter() - start_time
47
- logger.info(f"Time taken to load EmbeddingRetriever: {retriever_loading_time:.6f} seconds")
 
1
  import logging
 
2
  import time
3
+ from pathlib import Path
4
 
5
  import lancedb
6
  from sentence_transformers import SentenceTransformer
7
+
8
  import spaces
9
 
10
 
 
18
  proj_dir = Path(__file__).parents[1]
19
 
20
  # Log the time taken to load the QdrantDocumentStore
21
+ db = lancedb.connect(proj_dir / "lancedb")
22
  tbl = db.open_table('arabic-wiki')
23
  lancedb_loading_time = time.perf_counter() - start_time
24
  logger.info(f"Time taken to load LanceDB: {lancedb_loading_time:.6f} seconds")
 
26
  # Start the timer for loading the EmbeddingRetriever
27
  start_time = time.perf_counter()
28
 
29
+ name = "sentence-transformers/paraphrase-multilingual-minilm-l12-v2"
30
+ st_model_gpu = SentenceTransformer(name, device='mps')
31
+ st_model_cpu = SentenceTransformer(name, device='cpu')
32
+
33
 
34
  # used for both training and querying
35
+ def call_embed_func(query):
36
+ try:
37
+ return embed_func(query)
38
+ except:
39
+ logger.warning(f'Using CPU')
40
+ return st_model_cpu.encode(query)
41
+
42
+
43
  @spaces.GPU
44
  def embed_func(query):
45
+ return st_model_gpu.encode(query)
46
+
47
 
48
  def vector_search(query_vector, top_k):
49
  return tbl.search(query_vector).limit(top_k).to_list()
50
 
51
+
52
  def retriever(query, top_k=3):
53
+ query_vector = call_embed_func(query)
54
  documents = vector_search(query_vector, top_k)
55
  return documents
56
 
57
 
58
  # Log the time taken to load the EmbeddingRetriever
59
  retriever_loading_time = time.perf_counter() - start_time
60
+ logger.info(f"Time taken to load EmbeddingRetriever: {retriever_loading_time:.6f} seconds")