abhishekdileep commited on
Commit
f730c6b
·
1 Parent(s): 149f842

model loads twice sometimes with st.cache_resources

Browse files
Files changed (3) hide show
  1. app.py +3 -6
  2. model.py +8 -8
  3. rag.configs.yml +1 -2
app.py CHANGED
@@ -28,8 +28,6 @@ if "messages" not in st.session_state:
28
  if "doc" not in st.session_state:
29
  st.session_state.doc = None
30
 
31
- if "refresh" not in st.session_state:
32
- st.session_state.refresh = True
33
  if "refresh" not in st.session_state:
34
  st.session_state.refresh = True
35
 
@@ -42,19 +40,18 @@ if prompt := st.chat_input("Search Here insetad of Google"):
42
  st.chat_message("user").markdown(prompt)
43
  st.session_state.messages.append({"role": "user", "content": prompt})
44
 
 
45
  if st.session_state.refresh:
46
  st.session_state.refresh = False
47
  search(prompt)
48
 
49
  s = SemanticSearch(
50
  st.session_state.doc,
51
- st.session_state.st.session_state.configs["model"]["embeding_model"],
52
- st.session_state.st.session_state.configs["model"]["device"],
53
  )
54
  topk, u = s.semantic_search(query=prompt, k=32)
55
  output = st.session_state.model.answer_query(query=prompt, topk_items=topk)
56
- topk, u = s.semantic_search(query=prompt, k=32)
57
- output = st.session_state.model.answer_query(query=prompt, topk_items=topk)
58
  response = output
59
  with st.chat_message("assistant"):
60
  st.markdown(response)
 
28
  if "doc" not in st.session_state:
29
  st.session_state.doc = None
30
 
 
 
31
  if "refresh" not in st.session_state:
32
  st.session_state.refresh = True
33
 
 
40
  st.chat_message("user").markdown(prompt)
41
  st.session_state.messages.append({"role": "user", "content": prompt})
42
 
43
+ configs = st.session_state.configs
44
  if st.session_state.refresh:
45
  st.session_state.refresh = False
46
  search(prompt)
47
 
48
  s = SemanticSearch(
49
  st.session_state.doc,
50
+ configs["model"]["embeding_model"],
51
+ configs["model"]["device"],
52
  )
53
  topk, u = s.semantic_search(query=prompt, k=32)
54
  output = st.session_state.model.answer_query(query=prompt, topk_items=topk)
 
 
55
  response = output
56
  with st.chat_message("assistant"):
57
  st.markdown(response)
model.py CHANGED
@@ -68,13 +68,13 @@ class RAGModel:
68
 
69
  if __name__ == "__main__":
70
  configs = load_configs(config_file="rag.configs.yml")
71
- query = "Explain F1 racing for a beginer"
72
- g = GoogleSearch(query)
73
- data = g.all_page_data
74
- d = Document(data, 512)
75
- doc_chunks = d.doc()
76
- s = SemanticSearch(doc_chunks, "all-mpnet-base-v2", "mps")
77
- topk, u = s.semantic_search(query=query, k=32)
78
  r = RAGModel(configs)
79
- output = r.answer_query(query=query, topk_items=topk)
80
  print(output)
 
68
 
69
  if __name__ == "__main__":
70
  configs = load_configs(config_file="rag.configs.yml")
71
+ query = "The height of burj khalifa is 1000 meters and it was built in 2023. What is the height of burgj khalifa"
72
+ # g = GoogleSearch(query)
73
+ # data = g.all_page_data
74
+ # d = Document(data, 512)
75
+ # doc_chunks = d.doc()
76
+ # s = SemanticSearch(doc_chunks, "all-mpnet-base-v2", "mps")
77
+ # topk, u = s.semantic_search(query=query, k=32)
78
  r = RAGModel(configs)
79
+ output = r.answer_query(query=query, topk_items=[""])
80
  print(output)
rag.configs.yml CHANGED
@@ -4,5 +4,4 @@ document:
4
  model:
5
  embeding_model: all-mpnet-base-v2
6
  genration_model: google/gemma-7b-it
7
- device : cuda
8
-
 
4
  model:
5
  embeding_model: all-mpnet-base-v2
6
  genration_model: google/gemma-7b-it
7
+ device : cuda