Zwea Htet commited on
Commit
2abc521
1 Parent(s): 1c67be9

fixed f-string bugs

Browse files
Files changed (1) hide show
  1. pages/llama_custom_demo.py +17 -6
pages/llama_custom_demo.py CHANGED
@@ -11,7 +11,7 @@ from models.llms import (
11
  llm_gpt_3_5_turbo,
12
  llm_gpt_3_5_turbo_0125,
13
  llm_gpt_4_0125,
14
- llm_llama_13b_v2_replicate
15
  )
16
  from models.embeddings import hf_embed_model, openai_embed_model
17
  from models.llamaCustom import LlamaCustom
@@ -45,6 +45,7 @@ llama_llms = {
45
  # "meta/llama-2-13b-chat": llm_llama_13b_v2_replicate,
46
  }
47
 
 
48
  def init_session_state():
49
  if "llama_messages" not in st.session_state:
50
  st.session_state.llama_messages = [
@@ -60,6 +61,7 @@ def init_session_state():
60
  if "llama_custom" not in st.session_state:
61
  st.session_state.llama_custom = None
62
 
 
63
  # @st.cache_resource
64
  def index_docs(
65
  filename: str,
@@ -73,7 +75,7 @@ def index_docs(
73
 
74
  # test the index
75
  index.as_query_engine().query("What is the capital of France?")
76
-
77
  else:
78
  reader = SimpleDirectoryReader(input_files=[f"{SAVE_DIR}/{filename}"])
79
  docs = reader.load_data(show_progress=True)
@@ -81,7 +83,9 @@ def index_docs(
81
  documents=docs,
82
  show_progress=True,
83
  )
84
- index.storage_context.persist(persist_dir=f"vectorStores/{filename.replace(".", '_')}")
 
 
85
 
86
  except Exception as e:
87
  print(f"Error: {e}")
@@ -92,6 +96,7 @@ def index_docs(
92
  def load_llm(model_name: str):
93
  return llama_llms[model_name]
94
 
 
95
  init_session_state()
96
 
97
  st.set_page_config(page_title="Llama", page_icon="🦙")
@@ -102,7 +107,9 @@ tab1, tab2 = st.tabs(["Config", "Chat"])
102
 
103
  with tab1:
104
  with st.form(key="llama_form"):
105
- selected_llm_name = st.selectbox(label="Select a model:", options=llama_llms.keys())
 
 
106
 
107
  if selected_llm_name.startswith("openai"):
108
  # ask for the api key
@@ -140,7 +147,12 @@ with tab1:
140
  with tab2:
141
  messages_container = st.container(height=300)
142
  show_previous_messages(framework="llama", messages_container=messages_container)
143
- show_chat_input(disabled=False, framework="llama", model=st.session_state.llama_custom, messages_container=messages_container)
 
 
 
 
 
144
 
145
  def clear_history():
146
  messages_container.empty()
@@ -155,4 +167,3 @@ with tab2:
155
  if st.button("Clear Chat History"):
156
  clear_history()
157
  st.rerun()
158
-
 
11
  llm_gpt_3_5_turbo,
12
  llm_gpt_3_5_turbo_0125,
13
  llm_gpt_4_0125,
14
+ llm_llama_13b_v2_replicate,
15
  )
16
  from models.embeddings import hf_embed_model, openai_embed_model
17
  from models.llamaCustom import LlamaCustom
 
45
  # "meta/llama-2-13b-chat": llm_llama_13b_v2_replicate,
46
  }
47
 
48
+
49
  def init_session_state():
50
  if "llama_messages" not in st.session_state:
51
  st.session_state.llama_messages = [
 
61
  if "llama_custom" not in st.session_state:
62
  st.session_state.llama_custom = None
63
 
64
+
65
  # @st.cache_resource
66
  def index_docs(
67
  filename: str,
 
75
 
76
  # test the index
77
  index.as_query_engine().query("What is the capital of France?")
78
+
79
  else:
80
  reader = SimpleDirectoryReader(input_files=[f"{SAVE_DIR}/{filename}"])
81
  docs = reader.load_data(show_progress=True)
 
83
  documents=docs,
84
  show_progress=True,
85
  )
86
+ index.storage_context.persist(
87
+ persist_dir=f"vectorStores/{filename.replace('.', '_')}"
88
+ )
89
 
90
  except Exception as e:
91
  print(f"Error: {e}")
 
96
  def load_llm(model_name: str):
97
  return llama_llms[model_name]
98
 
99
+
100
  init_session_state()
101
 
102
  st.set_page_config(page_title="Llama", page_icon="🦙")
 
107
 
108
  with tab1:
109
  with st.form(key="llama_form"):
110
+ selected_llm_name = st.selectbox(
111
+ label="Select a model:", options=llama_llms.keys()
112
+ )
113
 
114
  if selected_llm_name.startswith("openai"):
115
  # ask for the api key
 
147
  with tab2:
148
  messages_container = st.container(height=300)
149
  show_previous_messages(framework="llama", messages_container=messages_container)
150
+ show_chat_input(
151
+ disabled=False,
152
+ framework="llama",
153
+ model=st.session_state.llama_custom,
154
+ messages_container=messages_container,
155
+ )
156
 
157
  def clear_history():
158
  messages_container.empty()
 
167
  if st.button("Clear Chat History"):
168
  clear_history()
169
  st.rerun()