Zwea Htet committed on
Commit
4bb745d
1 Parent(s): 3059501

fixed bugs in LlamaIndex custom demo and updated UI

.gitignore CHANGED
@@ -4,4 +4,6 @@ models/__pycache__
 .env
 __pycache__
 vectorStores
-.vscode
+.vscode
+.streamlit/secrets.toml
+uploaded_files
app.py CHANGED
@@ -8,23 +8,63 @@ import streamlit as st
 
 from utils.util import *
 
+from langchain.memory import ConversationBufferMemory
+
+SAVE_DIR = "uploaded_files"
+os.makedirs(SAVE_DIR, exist_ok=True)
+
+
+def init_session_state():
+    if "openai_api_key" not in st.session_state:
+        st.session_state.openai_api_key = ""
+
+    if "uploaded_files" not in st.session_state:
+        st.session_state.uploaded_files = os.listdir(SAVE_DIR)
+
+
+init_session_state()
+
 st.set_page_config(page_title="RegBotBeta", page_icon="📜🤖")
 
 st.title("Welcome to RegBotBeta2.0")
 st.header("Powered by `LlamaIndex🦙`, `Langchain🦜🔗 ` and `OpenAI API`")
 
-api_key = st.text_input("Enter your OpenAI API key here:", type="password")
-
-if api_key:
-    resp = validate(api_key)
-    if "error" in resp.json():
-        st.info("Invalid Token! Try again.")
-    else:
-        st.info("Success")
-        os.environ["OPENAI_API_KEY"] = api_key
-        openai.api_key = api_key
-
-    if "openai_api_key" not in st.session_state:
-        st.session_state.openai_api_key = ""
-
-    st.session_state.openai_api_key = api_key
+# openai_api_key = st.text_input(
+#     "OpenAI API Key",
+#     type="password",
+#     help="Get your API key from https://platform.openai.com/account/api-keys",
+#     value=st.session_state.openai_api_key,
+# )
+
+# isKeyValid = False
+# if openai_api_key:
+#     resp = validate(openai_api_key)
+#     if "error" in resp.json():
+#         st.info("Invalid Token! Try again.")
+#     else:
+#         st.info("Success")
+#         st.session_state.openai_api_key = openai_api_key
+#         isKeyValid = True
+
+uploaded_files = st.file_uploader(
+    "Upload Files",
+    accept_multiple_files=True,
+    type=["pdf", "docx", "txt", "csv"],
+)
+
+if uploaded_files:
+    for file in uploaded_files:
+        if file.name not in st.session_state.uploaded_files:
+            # add the file name to session state
+            st.session_state.uploaded_files.append(file.name)
+
+            # save the file to the uploaded_files directory
+            with open(os.path.join(SAVE_DIR, file.name), "wb") as f:
+                f.write(file.getbuffer())
+
+    st.success("File(s) uploaded successfully!")
+
+if st.session_state.uploaded_files:
+    st.write("Uploaded Files:")
+    for i, filename in enumerate(st.session_state.uploaded_files, start=1):
+        st.write(f"{i}. {filename}")
models/embeddings.py ADDED
@@ -0,0 +1,15 @@
+import os
+
+from llama_index.embeddings.langchain import LangchainEmbedding
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+from llama_index.embeddings.openai import OpenAIEmbedding
+
+hf_embed_model = HuggingFaceEmbeddings(
+    model_name="sentence-transformers/all-MiniLM-L6-v2",
+)
+
+# embed_model = LangchainEmbedding(hf_embed_model)
+
+openai_embed_model = OpenAIEmbedding(
+    api_key=os.getenv("OPENAI_API_KEY"),
+)
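
For context, a minimal sketch of how these embedding objects are meant to be consumed; this mirrors the wiring in pages/llama_custom_demo.py further down, and the input directory here is only an example:

# Sketch only: register the HuggingFace embedding model globally, then build an index with it.
from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex

from models.embeddings import hf_embed_model

Settings.embed_model = hf_embed_model  # every index built afterwards embeds with MiniLM
documents = SimpleDirectoryReader(input_dir="uploaded_files").load_data()
index = VectorStoreIndex.from_documents(documents)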
models/llamaCustom.py CHANGED
@@ -1,6 +1,7 @@
 import os
 import pickle
 from json import dumps, loads
+import time
 from typing import Any, List, Mapping, Optional
 
 import numpy as np
@@ -9,25 +10,31 @@ import pandas as pd
 import streamlit as st
 from dotenv import load_dotenv
 from huggingface_hub import HfFileSystem
-from langchain.llms.base import LLM
-from llama_index import (
-    Document,
-    GPTVectorStoreIndex,
-    LLMPredictor,
-    PromptHelper,
-    ServiceContext,
-    SimpleDirectoryReader,
-    StorageContext,
-    load_index_from_storage,
-)
-from pydantic import BaseModel
-from llama_index.llms.base import llm_completion_callback
-from llama_index.llms import CompletionResponse, CustomLLM, LLMMetadata
-from llama_index.prompts import Prompt
+
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, Pipeline
 
+# prompts
 from assets.prompts import custom_prompts
 
+# llama index
+from llama_index.core import (
+    StorageContext,
+    SimpleDirectoryReader,
+    VectorStoreIndex,
+    load_index_from_storage,
+    PromptHelper,
+    PromptTemplate,
+)
+from llama_index.core.llms import (
+    CustomLLM,
+    CompletionResponse,
+    LLMMetadata,
+)
+from llama_index.core.memory import ChatMemoryBuffer
+from llama_index.core.llms.callbacks import llm_completion_callback
+from llama_index.core.base.llms.types import ChatMessage
+from llama_index.core import Settings
+
 load_dotenv()
 # openai.api_key = os.getenv("OPENAI_API_KEY")
 fs = HfFileSystem()
@@ -40,9 +47,41 @@ NUM_OUTPUT = 525
 # set maximum chunk overlap
 CHUNK_OVERLAP_RATION = 0.2
 
-text_qa_template = Prompt(custom_prompts.text_qa_template_str)
+# TODO: use the following prompt to format the answer at the end of the context prompt
+ANSWER_FORMAT = """
+Use the following example format for your answer:
+[FORMAT]
+Answer:
+The answer to the user question.
+Reference:
+The list of references to the specific sections of the documents that support your answer.
+[END_FORMAT]
+"""
 
-refine_template = Prompt(custom_prompts.refine_template_str)
+CONTEXT_PROMPT_TEMPLATE = """
+The following is a friendly conversation between a user and an AI assistant.
+The assistant is talkative and provides lots of specific details from its context.
+If the assistant does not know the answer to a question, it truthfully says it
+does not know.
+
+Here are the relevant documents for the context:
+
+{context_str}
+
+Instruction: Based on the above documents, provide a detailed answer for the user question below. \
+Include references to the specific sections of the documents that support your answer. \
+Answer "don't know" if not present in the document.
+"""
+
+CONDENSE_PROMPT_TEMPLATE = """
+Given the following conversation between a user and an AI assistant and a follow up question from user,
+rephrase the follow up question to be a standalone question.
+
+Chat History:
+{chat_history}
+Follow Up Input: {question}
+Standalone question:
+"""
 
 
 @st.cache_resource
@@ -67,10 +106,8 @@ def load_model(model_name: str):
 
 
 class OurLLM(CustomLLM):
-    # def __init__(self, model_name: str, pipeline):
-    #     super().__init__()  # Call the __init__ method of CustomLLM
-    #     self.model_name = model_name
-    #     self.pipeline = pipeline
+    context_window: int = 3900
+    num_output: int = 256
     model_name: str = ""
     pipeline: Pipeline = None
 
@@ -83,6 +120,7 @@ class OurLLM(CustomLLM):
             model_name=self.model_name,
         )
 
+    # The decorator is optional, but provides observability via callbacks on the LLM calls.
    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        prompt_length = len(prompt)
@@ -93,67 +131,39 @@ class OurLLM(CustomLLM):
         return CompletionResponse(text=text)
 
     @llm_completion_callback()
-    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
-        raise NotImplementedError()
+    def stream_complete(self, prompt: str, **kwargs: Any):
+        response = ""
+        for token in self.dummy_response:
+            response += token
+            yield CompletionResponse(text=response, delta=token)
 
 
 class LlamaCustom:
-    def __init__(self, model_name: str) -> None:
-        self.vector_index = self.initialize_index(model_name=model_name)
-
-    def initialize_index(self, model_name: str):
-        index_name = model_name.split("/")[-1]
-
-        file_path = f"./vectorStores/{index_name}"
-
-        if os.path.exists(path=file_path):
-            # rebuild storage context
-            storage_context = StorageContext.from_defaults(persist_dir=file_path)
-
-            # local load index access
-            index = load_index_from_storage(storage_context)
-
-            # huggingface repo load access
-            # with fs.open(file_path, "r") as file:
-            #     index = pickle.loads(file.readlines())
-            return index
-        else:
-            prompt_helper = PromptHelper(
-                context_window=CONTEXT_WINDOW,
-                num_output=NUM_OUTPUT,
-                chunk_overlap_ratio=CHUNK_OVERLAP_RATION,
-            )
-
-            # define llm
-            pipe = load_model(model_name=model_name)
-            llm = OurLLM(model_name=model_name, pipeline=pipe)
-
-            llm_predictor = LLMPredictor(llm=llm)
-            service_context = ServiceContext.from_defaults(
-                llm_predictor=llm_predictor, prompt_helper=prompt_helper
-            )
-
-            # documents = prepare_data(r"./assets/regItems.json")
-            documents = SimpleDirectoryReader(input_dir="./assets/pdf").load_data()
-
-            index = GPTVectorStoreIndex.from_documents(
-                documents, service_context=service_context
-            )
-
-            # local write access
-            index.storage_context.persist(file_path)
-
-            # huggingface repo write access
-            # with fs.open(file_path, "w") as file:
-            #     file.write(pickle.dumps(index))
-            return index
-
-    def get_response(self, query_str):
-        print("query_str: ", query_str)
-        # query_engine = self.vector_index.as_query_engine()
-        query_engine = self.vector_index.as_query_engine(
-            text_qa_template=text_qa_template, refine_template=refine_template
+    def __init__(self, model_name: str, index: VectorStoreIndex):
+        self.model_name = model_name
+        self.index = index
+        self.chat_mode = "condense_plus_context"
+        self.memory = ChatMemoryBuffer.from_defaults()
+
+    def get_response(self, query_str: str, chat_history: List[ChatMessage]):
+        # https://docs.llamaindex.ai/en/stable/module_guides/deploying/chat_engines/
+        # query_engine = self.index.as_query_engine(
+        #     text_qa_template=text_qa_template, refine_template=refine_template
+        # )
+        chat_engine = self.index.as_chat_engine(
+            chat_mode=self.chat_mode,
+            memory=self.memory,
+            context_prompt=CONTEXT_PROMPT_TEMPLATE,
+            condense_prompt=CONDENSE_PROMPT_TEMPLATE,
+            # verbose=True,
         )
-        response = query_engine.query(query_str)
-        print("metadata: ", response.metadata)
+        # response = query_engine.query(query_str)
+        response = chat_engine.chat(message=query_str, chat_history=chat_history)
+
         return str(response)
+
+    def get_stream_response(self, query_str: str, chat_history: List[ChatMessage]):
+        response = self.get_response(query_str=query_str, chat_history=chat_history)
+        for word in response.split():
+            yield word + " "
+            time.sleep(0.05)
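
For reference, the reworked LlamaCustom no longer builds its own index; it wraps an index you pass in and answers through a condense-plus-context chat engine. A minimal usage sketch (the file path is hypothetical; the real flow lives in pages/llama_custom_demo.py below):

# Sketch only: index one document and chat through LlamaCustom.
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.core.base.llms.types import ChatMessage

from models.llamaCustom import LlamaCustom

docs = SimpleDirectoryReader(input_files=["uploaded_files/example.pdf"]).load_data()
index = VectorStoreIndex.from_documents(docs)

bot = LlamaCustom(model_name="openai/gpt-3.5-turbo-0125", index=index)
history = [ChatMessage.from_str(role="assistant", content="How can I help you today?")]
print(bot.get_response("What does the document cover?", chat_history=history))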
models/llms.py ADDED
@@ -0,0 +1,44 @@
+from llama_index.llms.huggingface import HuggingFaceLLM, HuggingFaceInferenceAPI
+from llama_index.llms.openai import OpenAI
+from llama_index.llms.replicate import Replicate
+from dotenv import load_dotenv
+import os
+
+load_dotenv()
+
+llm_mixtral_8x7b = HuggingFaceInferenceAPI(
+    model_name="mistralai/Mixtral-8x7B-Instruct-v0.1",
+    token=os.getenv("HUGGINGFACE_API_TOKEN"),
+)
+
+# download the model from the Hugging Face Hub and run it locally
+# llm_mixtral_8x7b = HuggingFaceLLM(model_name="mistralai/Mixtral-8x7B-Instruct-v0.1")
+
+llm_llama_2_7b_chat = HuggingFaceInferenceAPI(
+    model_name="meta-llama/Llama-2-7b-chat-hf",
+    token=os.getenv("HUGGINGFACE_API_TOKEN"),
+)
+
+llm_bloomz_560m = HuggingFaceInferenceAPI(
+    model_name="bigscience/bloomz-560m",
+    token=os.getenv("HUGGINGFACE_API_TOKEN"),
+)
+
+llm_gpt_3_5_turbo = OpenAI(
+    api_key=os.getenv("OPENAI_API_KEY"),
+)
+
+llm_gpt_3_5_turbo_0125 = OpenAI(
+    model="gpt-3.5-turbo-0125",
+    api_key=os.getenv("OPENAI_API_KEY"),  # read the key from the environment; never hardcode it
+)
+
+llm_gpt_4_0125 = OpenAI(
+    model="gpt-4-0125-preview",
+    api_key=os.getenv("OPENAI_API_KEY"),
+)
+
+llm_llama_13b_v2_replicate = Replicate(
+    model="meta/llama-2-13b-chat",  # the Replicate client reads REPLICATE_API_TOKEN from the environment
+)
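
These module-level LLM objects are meant to be looked up from a dictionary and installed as the global LlamaIndex LLM; a short sketch of that pattern (the model choice here is just an example, mirroring what the demo page does via load_llm()):

# Sketch only: select one of the pre-built LLMs and make it the default for all queries.
from llama_index.core import Settings

from models.llms import llm_mixtral_8x7b

Settings.llm = llm_mixtral_8x7b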
pages/llama_custom_demo.py CHANGED
@@ -1,26 +1,158 @@
+import random
+import time
+import streamlit as st
 import os
+import pathlib
+from typing import List
+from models.llms import (
+    llm_llama_2_7b_chat,
+    llm_mixtral_8x7b,
+    llm_bloomz_560m,
+    llm_gpt_3_5_turbo,
+    llm_gpt_3_5_turbo_0125,
+    llm_gpt_4_0125,
+    llm_llama_13b_v2_replicate,
+)
+from models.embeddings import hf_embed_model, openai_embed_model
+from models.llamaCustom import LlamaCustom
 
-import openai
-import streamlit as st
-
-from models.llamaCustom import LlamaCustom
-from utils.chatbox import chatbox
+# from models.llamaCustom import LlamaCustom
+from utils.chatbox import show_previous_messages, show_chat_input
+from llama_index.core import (
+    SimpleDirectoryReader,
+    Document,
+    VectorStoreIndex,
+    StorageContext,
+    Settings,
+    load_index_from_storage,
+)
+from llama_index.core.memory import ChatMemoryBuffer
+from llama_index.core.base.llms.types import ChatMessage
+
+SAVE_DIR = "uploaded_files"
+VECTOR_STORE_DIR = "vectorStores"
+
+# global
+Settings.embed_model = hf_embed_model
+
+llama_llms = {
+    "bigscience/bloomz-560m": llm_bloomz_560m,
+    "mistral/mixtral": llm_mixtral_8x7b,
+    "meta-llama/Llama-2-7b-chat-hf": llm_llama_2_7b_chat,
+    # "openai/gpt-3.5-turbo": llm_gpt_3_5_turbo,
+    "openai/gpt-3.5-turbo-0125": llm_gpt_3_5_turbo_0125,
+    # "openai/gpt-4-0125-preview": llm_gpt_4_0125,
+    # "meta/llama-2-13b-chat": llm_llama_13b_v2_replicate,
+}
+
+
+def init_session_state():
+    if "llama_messages" not in st.session_state:
+        st.session_state.llama_messages = [
+            {"role": "assistant", "content": "How can I help you today?"}
+        ]
+
+    # TODO: create a chat history for each different document
+    if "llama_chat_history" not in st.session_state:
+        st.session_state.llama_chat_history = [
+            ChatMessage.from_str(role="assistant", content="How can I help you today?")
+        ]
+
+    if "llama_custom" not in st.session_state:
+        st.session_state.llama_custom = None
+
+
+# @st.cache_resource
+def index_docs(
+    filename: str,
+) -> VectorStoreIndex:
+    try:
+        index_path = pathlib.Path(f"{VECTOR_STORE_DIR}/{filename.replace('.', '_')}")
+        if pathlib.Path.exists(index_path):
+            print("Loading index from storage ...")
+            storage_context = StorageContext.from_defaults(persist_dir=index_path)
+            index = load_index_from_storage(storage_context=storage_context)
+
+            # test the index
+            index.as_query_engine().query("What is the capital of France?")
+
+        else:
+            reader = SimpleDirectoryReader(input_files=[f"{SAVE_DIR}/{filename}"])
+            docs = reader.load_data(show_progress=True)
+            index = VectorStoreIndex.from_documents(
+                documents=docs,
+                show_progress=True,
+            )
+            index.storage_context.persist(
+                persist_dir=f"{VECTOR_STORE_DIR}/{filename.replace('.', '_')}"
+            )
+
+    except Exception as e:
+        print(f"Error: {e}")
+        index = None
+    return index
+
+
+def load_llm(model_name: str):
+    return llama_llms[model_name]
+
+
+init_session_state()
 
 st.set_page_config(page_title="Llama", page_icon="🦙")
 
-st.subheader("Llama Index with Custom LLM Demo")
-
-if "messages" not in st.session_state:
-    st.session_state.messages = []
-
-if "openai_api_key" not in st.session_state:
-    st.info("Enter your openai key to access the chatbot.")
-else:
-    option = st.selectbox(
-        label="Select your model:", options=("bigscience/bloom-560m",)
-    )
-
-    # with st.spinner("Initializing vector index"):
-    model = LlamaCustom(model_name=option)
-
-    chatbox("llama_custom", model)
+st.header("Llama Index with Custom LLM Demo")
+
+tab1, tab2 = st.tabs(["Config", "Chat"])
+
+with tab1:
+    with st.form(key="llama_form"):
+        selected_llm_name = st.selectbox(
+            label="Select a model:", options=llama_llms.keys()
+        )
+
+        if selected_llm_name.startswith("openai"):
+            # ask for the api key
+            if st.secrets.get("OPENAI_API_KEY") is None:
+                # st.stop()
+                st.info("OpenAI API Key not found in secrets. Please enter it below.")
+                # st.secrets is read-only at runtime, so keep the key in the environment
+                os.environ["OPENAI_API_KEY"] = st.text_input(
+                    "OpenAI API Key",
+                    type="password",
+                    help="Get your API key from https://platform.openai.com/account/api-keys",
+                )
+
+        selected_file = st.selectbox(
+            label="Choose a file to chat with: ", options=os.listdir(SAVE_DIR)
+        )
+
+        if st.form_submit_button(label="Submit"):
+            with st.status("Loading ...", expanded=True) as status:
+                st.write("Loading Model ...")
+                llama_llm = load_llm(selected_llm_name)
+                Settings.llm = llama_llm
+
+                st.write("Processing Data ...")
+                index = index_docs(selected_file)
+                if index is None:
+                    st.error("Failed to index the documents.")
+                    st.stop()
+
+                st.write("Finishing Up ...")
+                llama_custom = LlamaCustom(model_name=selected_llm_name, index=index)
+                st.session_state.llama_custom = llama_custom
+
+                status.update(label="Ready to query!", state="complete", expanded=False)
+
+with tab2:
+    messages_container = st.container(height=300)
+    show_previous_messages(framework="llama", messages_container=messages_container)
+    show_chat_input(
+        disabled=False,
+        framework="llama",
+        model=st.session_state.llama_custom,
+        messages_container=messages_container,
+    )
+
+    def clear_history():
+        messages_container.empty()
+        st.session_state.llama_messages = [
+            {"role": "assistant", "content": "How can I help you today?"}
+        ]
+
+        st.session_state.llama_chat_history = [
+            ChatMessage.from_str(role="assistant", content="How can I help you today?")
+        ]
+
+    if st.button("Clear Chat History"):
+        clear_history()
+        st.rerun()
requirements.txt CHANGED
@@ -1,14 +1,15 @@
-llama_index==0.8.64.post1
-torch
+llama_index>=0.10.27
 transformers
-panda
-numpy
-langchain
+pandas
+langchain>=0.1.11
 openai
 faiss-cpu
 python-dotenv
 streamlit>=1.24.0
-huggingface_hub
-xformers
+huggingface_hub>=0.21.4
 pypdf
-pymupdf
+llama-index-llms-huggingface>=0.1.4
+llama-index-embeddings-langchain>=0.1.2
+replicate>=0.25.1
+llama-index-llms-replicate
+sentence-transformers>=2.6.1
utils/chatbox.py CHANGED
@@ -1,93 +1,61 @@
 import time
-
+import json
 import streamlit as st
+from typing import Dict, List, Any
 
+from llama_index.core.base.llms.types import ChatMessage
 
-def display_chat_history(model_name: str):
-    for message in st.session_state[model_name]:
-        with st.chat_message(message["role"]):
-            st.markdown(message["content"])
-
-def chat_input(model_name: str):
-    if prompt := st.chat_input("Say something"):
-        # Display user message in chat message container
-        st.chat_message("user").markdown(prompt)
-
-        # Add user message to chat history
-        st.session_state[model_name].append({"role": "user", "content": prompt})
-
-        return prompt
-
-def display_bot_msg(model_name: str, bot_response: str):
-    # Display assistant response in chat message container
-    with st.chat_message("assistant"):
-        message_placeholder = st.empty()
-        full_response = ""
-
-        # simulate the chatbot "thinking" before responding
-        # (or stream its response)
-        for chunk in bot_response.split():
-            full_response += chunk + " "
-            time.sleep(0.05)
-
-            # add a blinking cursor to simulate typing
-            message_placeholder.markdown(full_response + "▌")
-
-        message_placeholder.markdown(full_response)
-        # st.markdown(response)
-
-        # Add assistant response to chat history
-        st.session_state[model_name].append(
-            {"model_name": model_name, "role": "assistant", "content": full_response}
-        )
-
-def chatbox(model_name: str, model: None):
-    # Display chat messages from history on app rerun
-    for message in st.session_state.messages:
-        if (message["model_name"] == model_name):
+
+def show_previous_messages(framework: str, messages_container: any):
+    with messages_container:
+        messages: List[Dict[str, Any]] = st.session_state[f"{framework}_messages"]
+        for message in messages:
             with st.chat_message(message["role"]):
                 st.markdown(message["content"])
 
-    if prompt := st.chat_input("Say something"):
-        # Display user message in chat message container
-        st.chat_message("user").markdown(prompt)
-
-        # Add user message to chat history
-        st.session_state.messages.append({"model_name": model_name, "role": "user", "content": prompt})
-
-        with st.spinner("Processing your query..."):
-            bot_response = model.get_response(prompt)
-
-        print("bot: ", bot_response)
 
-        # Display assistant response in chat message container
-        with st.chat_message("assistant"):
-            message_placeholder = st.empty()
-            full_response = ""
+def show_chat_input(
+    disabled: bool, framework: str, model: any, messages_container: any
+):
+    if disabled:
+        st.info("Make sure to select a model and file to start chatting!")
 
-            # simulate the chatbot "thinking" before responding
-            # (or stream its response)
-            for chunk in bot_response.split():
-                full_response += chunk + " "
-                time.sleep(0.05)
-
-                # add a blinking cursor to simulate typing
-                message_placeholder.markdown(full_response + "▌")
-
-            message_placeholder.markdown(full_response)
-            # st.markdown(response)
+    if prompt := st.chat_input("Say something", disabled=disabled):
+        st.session_state[f"{framework}_messages"].append(
+            {"role": "user", "content": prompt}
+        )
 
-            # Add assistant response to chat history
-            st.session_state.messages.append(
-                {"model_name": model_name, "role": "assistant", "content": full_response}
+        st.session_state[f"{framework}_chat_history"].append(
+            ChatMessage.from_str(role="user", content=prompt)
         )
 
-        # Scroll to the bottom of the chat container
-        # st.markdown(
-        #     """
-        #     <script>
-        #         const chatContainer = document.getElementsByClassName("css-1n76uvr")[0];
-        #         chatContainer.scrollTop = chatContainer.scrollHeight;
-        #     </script>
-        #     """,
-        #     unsafe_allow_html=True,
-        # )
+        # if st.session_state[f"{framework}_messages"][-1]["role"] == "assistant":
+        with messages_container:
+            with st.chat_message("user"):
+                st.write(prompt)
+
+            with st.chat_message("assistant"):
+                with st.spinner("Thinking..."):
+                    try:
+                        ai_response = model.get_response(
+                            query_str=prompt,
+                            chat_history=st.session_state[f"{framework}_chat_history"],
+                        )
+                        # when streaming, the response format is gone
+                        # ai_response = model.get_stream_response(
+                        #     query_str=prompt,
+                        #     chat_history=st.session_state[f"{framework}_chat_history"],
+                        # )
+                    except Exception as e:
+                        ai_response = f"An error occurred: {e}"
+
+                st.write(ai_response)
+                # response = st.write_stream(ai_response)
+
+        st.session_state[f"{framework}_messages"].append(
+            {"role": "assistant", "content": ai_response}
+        )
+
+        st.session_state[f"{framework}_chat_history"].append(
+            ChatMessage.from_str(role="assistant", content=ai_response)
+        )