# TFM / app.py
from llama_index import SimpleDirectoryReader, Prompt, LLMPredictor, VectorStoreIndex, PromptHelper, ServiceContext, load_index_from_storage, StorageContext
from langchain.chat_models import ChatOpenAI
import os
import openai
import tiktoken
import gradio as gr
#if you have OpenAI API key as an environment variable, enable the below
#openai.api_key = os.getenv("OPENAI_API_KEY")
#if you have OpenAI API key as a string, enable the below
openai.api_key = ""
os.environ["OPENAI_API_KEY"] = ''
#set system prompt
#messages = [
# {"role":"system", "content": "This Chatbot is helpful, accurate, and provides all relevant information from the Treasury Financial Manual (TFM) when responding. This Chatbot provides accurate citations from the Treasury financial Manual in the following format for all reponses, Volume>Part>chapter>Sub Chapter"},
# {"role":"user","content": "where should public money be stored?"},
# {"role":"assistant","content": "Public money should be stored in Treasury accounts as outlined in the Treasury Financial Manual (TFM). Specifically, Section 2025 of the TFM provides guidance on requesting a Treasury account. Entities should submit requests to the Fiscal Service, Budget Reporting Branch, Budget and Appropriations Analysis Section (BAAS) at baasgroup@fiscal.treasury.gov. It is recommended to submit these requests as early in the month as possible to allow for timely establishment and coordination. This information was located in Volume I>Part 1>Chapter 2000>Section 2045—Making Deposits"}
#]
messages = []
Chat_message = []
chat_history=[]
custom_chat_history = []
max_input_size = 4096
num_outputs = 512
chunk_size_limit = 1000
chunk_overlap_ratio = .1
prompt_helper = PromptHelper(max_input_size, num_outputs, chunk_overlap_ratio, chunk_size_limit)
store = './storage'
max_response_tokens = 1000
token_limit= 4097
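# Token budget note: my_chatbot below resets the conversation once
# num_tokens_from_messages(messages) reaches token_limit - max_response_tokens,
# leaving room for the model's reply; these values assume gpt-3.5-turbo's
# roughly 4k-token context window.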
template = (
    "This Chatbot is helpful, accurate, and provides a detailed, summarized response to all questions. \n"
    "---------------------\n"
    "{context_str}"
    "\n---------------------\n"
    "Given this information, please answer the question and provide accurate citations from the Treasury Financial Manual in the following format, Volume>Part>Chapter>Section: {query_str}\n"
)
qa_template = Prompt(template)
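# Illustrative only: when llama_index renders this template, {context_str} is filled
# with the retrieved TFM passages and {query_str} with the user's question, e.g.
#
#   This Chatbot is helpful, accurate, and provides a detailed, summarized response to all questions.
#   ---------------------
#   <text of the retrieved TFM chunks>
#   ---------------------
#   Given this information, please answer the question and provide accurate citations
#   from the Treasury Financial Manual in the following format, Volume>Part>Chapter>Section:
#   where should public money be stored?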
#create index prompt value
#createIndex = input('Do you wish to recreate the index (Y/N)? ')
#set LLM predictor values
#llm_predictor = LLMPredictor(llm=ChatOpenAI(temperature=0.1, model_name="gpt-3.5-turbo", max_tokens=num_outputs))
#Set service context
#service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
#loads openai key
def load_api_key(api_key):
    os.environ["OPENAI_API_KEY"] = str(api_key)
    openai.api_key = str(api_key)
#identifies the current number of tokens used for the conversation
def num_tokens_from_messages(messages, model="gpt-3.5-turbo"):
    encoding = tiktoken.encoding_for_model(model)
    num_tokens = 0
    for message in messages:
        num_tokens += 4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
            if key == "name":  # if there's a name, the role is omitted
                num_tokens += -1  # role is always required and always 1 token
    num_tokens += 2  # every reply is primed with <im_start>assistant
    print(num_tokens)
    return num_tokens
#constructs the index and saves it to a subfolder
def construct_index():
    print('Constructing index')
    # load the documents from the docs subfolder
    docs = SimpleDirectoryReader('docs').load_data()
    # build the service context here because the module-level one is commented out
    llm_predictor = LLMPredictor(llm=ChatOpenAI(temperature=0.1, model_name="gpt-3.5-turbo", max_tokens=num_outputs))
    service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
    index = VectorStoreIndex.from_documents(docs, service_context=service_context)
    index.set_index_id('vector_index')
    # store the JSON index files in a subfolder
    index.storage_context.persist(persist_dir=store)
    return index
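# A minimal usage sketch (assumption: the TFM source documents live in ./docs).
# Rebuild the index whenever those documents change, before launching the app:
#
#   if not os.path.exists(store):   # first run, or after deleting ./storage
#       construct_index()           # writes the vector index JSON files to ./storage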
#resets the conversation when the token limit is reached
def generate_restart(prompt):
    messages.clear()
    messages.append({"role": "system", "content": "Tell the user that this conversation has been reset due to the discussion size reaching maximum size, and to please start by asking a new question."})
    storage_context = StorageContext.from_defaults(persist_dir=store)
    llm_predictor = LLMPredictor(llm=ChatOpenAI(temperature=0.1, model_name="gpt-3.5-turbo", max_tokens=num_outputs))
    service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
    index = load_index_from_storage(storage_context, service_context=service_context)
    chat_engine = index.as_chat_engine(text_qa_template=qa_template)
    # the chat engine expects a string, so pass the stringified system message
    string_message = str(messages)
    response = chat_engine.chat(string_message)
    messages.clear()
    messages.append({"role": "system", "content": "This Chatbot is helpful, accurate, and provides all relevant information from the Treasury Financial Manual (TFM) when responding. This Chatbot always provides accurate citations from the TFM."})
    messages.append({"role": "user", "content": ""})
    messages.append({"role": "assistant", "content": ""})
    print("restart initiated")
    print(messages)
    return response.response
#generates the ChatGPT call
def generate_response(prompt):
    messages.append({"role": "user", "content": prompt})
    storage_context = StorageContext.from_defaults(persist_dir=store)
    service_context = ServiceContext.from_defaults(llm=ChatOpenAI(temperature=0., model_name="gpt-3.5-turbo"))
    index = load_index_from_storage(storage_context, service_context=service_context)
    #chat_engine = index.as_chat_engine(chat_mode='react', verbose=True, text_qa_template=qa_template)
    chat_engine = index.as_chat_engine(verbose=True, text_qa_template=qa_template)
    response = chat_engine.chat(prompt, chat_history)
    string_response = str(response)
    custom_chat_history.append("Human: " + prompt + " Assistant: " + string_response)
    messages.append({"role": "assistant", "content": string_response})
    num_tokens_from_messages(messages)
    print(messages)
    print("below is history")
    print(chat_history)
    print("below is type")
    print(type(custom_chat_history))
    return response.response
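# Note: chat_history above is a plain list that is never populated, so the chat
# engine does not actually receive earlier turns. A hedged sketch of one way to
# carry history across calls, assuming the installed llama_index version exposes
# ChatMessage in llama_index.llms:
#
#   from llama_index.llms import ChatMessage
#   response = chat_engine.chat(prompt, chat_history)                        # pass prior turns
#   chat_history.append(ChatMessage(role="user", content=prompt))
#   chat_history.append(ChatMessage(role="assistant", content=str(response)))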
def generate_response_backup(prompt):
    messages.append({"role": "user", "content": prompt})
    storage_context = StorageContext.from_defaults(persist_dir=store)
    llm_predictor = LLMPredictor(llm=ChatOpenAI(temperature=0.1, model_name="gpt-3.5-turbo", max_tokens=num_outputs))
    service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
    index = load_index_from_storage(storage_context, service_context=service_context)
    query_engine = index.as_query_engine(text_qa_template=qa_template)
    string_message = str(messages)
    response = query_engine.query(string_message)
    string_response = str(response)
    messages.append({"role": "assistant", "content": string_response})
    string_message = ''
    num_tokens_from_messages(messages)
    print(messages)
    return response.response
#Called by the send button on the interface and tracks the history of the conversation
def my_chatbot_backup(input, history):
    history = history or []
    if num_tokens_from_messages(messages) < (int(token_limit) - int(max_response_tokens)):
        output = generate_response(input)
        history.append((input, output))
        return history, history
    else:
        history.clear()
        output = generate_restart(input)
        history.append((input, output))
        return history, history
def my_chatbot(input, history):
    history = history or []
    if num_tokens_from_messages(messages) < (int(token_limit) - int(max_response_tokens)):
        output = generate_response(input)
        history.append((input, output))
        return history, history
    else:
        history.clear()
        output = generate_restart(input)
        history.append((input, output))
        # return the reset notice through the chat window, mirroring my_chatbot_backup
        return history, history
# One-shot query path (not wired to the interface); caches the loaded index on the function
def index_chatbot(input_text):
    if not hasattr(index_chatbot, 'index'):
        storage_context = StorageContext.from_defaults(persist_dir=store)
        index_chatbot.index = load_index_from_storage(storage_context)
    query_engine = index_chatbot.index.as_query_engine(text_qa_template=qa_template)
    response = query_engine.query(input_text)
    return response.response
with gr.Blocks() as demo:
    gr.Markdown("""<h1><center>TFM Chatbot 2.0</center></h1>""")
    api_key = gr.Textbox(type='password', label="Enter the API key")
    chatbot = gr.Chatbot()
    state = gr.State()
    text = gr.Textbox(placeholder="Hello. Ask me a question about the Treasury Financial Manual.")
    submit = gr.Button("SEND")
    submit.click(load_api_key, [api_key])
    submit.click(my_chatbot, inputs=[text, state], outputs=[chatbot, state])

#if createIndex in ["Y","y"]:
#    construct_index()
demo.launch(share=False)