alvinhenrick commited on
Commit
a7d44c2
1 Parent(s): 99db4cb

switch sematic caching to all-mpnet-base-v2

Browse files
app.py CHANGED
@@ -20,7 +20,7 @@ turbo = dspy.OpenAI(model='gpt-4o')
20
  dspy.settings.configure(lm=turbo, rm=rm)
21
 
22
  rag = RAG(k=5)
23
- sm = SemanticCaching(model_name='dmis-lab/biobert-base-cased-v1.2', dimension=768,
24
  json_file='rag_test_cache.json', cosine_threshold=.85, rag=rag)
25
  sm.load_cache()
26
 
 
20
  dspy.settings.configure(lm=turbo, rm=rm)
21
 
22
  rag = RAG(k=5)
23
+ sm = SemanticCaching(model_name='sentence-transformers/all-mpnet-base-v2', dimension=768,
24
  json_file='rag_test_cache.json', cosine_threshold=.85, rag=rag)
25
  sm.load_cache()
26
 
medirag/cache/local.py CHANGED
@@ -8,7 +8,7 @@ from medirag.rag.qa import RAG
8
 
9
  class SemanticCaching:
10
  def __init__(self,
11
- model_name='dmis-lab/biobert-base-cased-v1.2',
12
  dimension=768,
13
  json_file='cache.json',
14
  cosine_threshold=0.7,
 
8
 
9
  class SemanticCaching:
10
  def __init__(self,
11
+ model_name='sentence-transformers/all-mpnet-base-v2',
12
  dimension=768,
13
  json_file='cache.json',
14
  cosine_threshold=0.7,
tests/cache/test_semantic_cache.py CHANGED
@@ -7,7 +7,7 @@ from medirag.cache.local import SemanticCaching
7
  @pytest.fixture(scope="module")
8
  def semantic_caching():
9
  # This will actually initialize the model and the index
10
- return SemanticCaching(model_name='dmis-lab/biobert-base-cased-v1.2', dimension=768,
11
  json_file='real_test_cache.json')
12
 
13
 
 
7
  @pytest.fixture(scope="module")
8
  def semantic_caching():
9
  # This will actually initialize the model and the index
10
+ return SemanticCaching(model_name='sentence-transformers/all-mpnet-base-v2', dimension=768,
11
  json_file='real_test_cache.json')
12
 
13
 
tests/rag/test_rag.py CHANGED
@@ -26,7 +26,7 @@ def test_rag_with_example(data_dir):
26
 
27
  rag = RAG(k=3)
28
 
29
- sm = SemanticCaching(model_name='dmis-lab/biobert-base-cased-v1.2', dimension=768,
30
  json_file='rag_test_cache.json', rag=rag)
31
  sm.load_cache()
32
 
 
26
 
27
  rag = RAG(k=3)
28
 
29
+ sm = SemanticCaching(model_name='sentence-transformers/all-mpnet-base-v2', dimension=768,
30
  json_file='rag_test_cache.json', rag=rag)
31
  sm.load_cache()
32