from llama_index import SimpleDirectoryReader, Prompt, LLMPredictor, VectorStoreIndex, PromptHelper, ServiceContext, load_index_from_storage, StorageContext
from langchain.chat_models import ChatOpenAI
import os
import openai
import tiktoken
import gradio as gr

#if you have OpenAI API key as an environment variable, enable the below
#openai.api_key = os.getenv("OPENAI_API_KEY")

#if you have OpenAI API key as a string, enable the below
openai.api_key = ""
os.environ["OPENAI_API_KEY"] = ""
#set system prompt
#messages = [
#    {"role":"system", "content": "This Chatbot is helpful, accurate, and provides all relevant information from the Treasury Financial Manual (TFM) when responding. This Chatbot provides accurate citations from the Treasury Financial Manual in the following format for all responses: Volume>Part>Chapter>Sub Chapter"},
#    {"role":"user","content": "Where should public money be stored?"},
#    {"role":"assistant","content": "Public money should be stored in Treasury accounts as outlined in the Treasury Financial Manual (TFM). Specifically, Section 2025 of the TFM provides guidance on requesting a Treasury account. Entities should submit requests to the Fiscal Service, Budget Reporting Branch, Budget and Appropriations Analysis Section (BAAS) at baasgroup@fiscal.treasury.gov. It is recommended to submit these requests as early in the month as possible to allow for timely establishment and coordination. This information was located in Volume I>Part 1>Chapter 2000>Section 2045—Making Deposits"}
#]
messages = []
Chat_message = []
chat_history = []
custom_chat_history = []

max_input_size = 4096
num_outputs = 512
chunk_size_limit = 1000
chunk_overlap_ratio = .1
prompt_helper = PromptHelper(max_input_size, num_outputs, chunk_overlap_ratio, chunk_size_limit)

store = './storage'
max_response_tokens = 1000
token_limit = 4097
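# Note: the chatbot handlers below reset the conversation once num_tokens_from_messages(messages)
# reaches token_limit - max_response_tokens (4097 - 1000 = 3097 tokens), leaving headroom for the
# model's reply.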
template = (
    "This Chatbot is helpful, accurate, and provides a detailed summarized response to all questions. \n"
    "---------------------\n"
    "{context_str}"
    "\n---------------------\n"
    "Given this information, please answer the question and provide accurate citations from the Treasury Financial Manual in the following format, Volume>Part>Chapter>Section: {query_str}\n"
)
qa_template = Prompt(template)
#create index prompt value
#createIndex = input('Do you wish to recreate the index (Y/N)? ')

#set LLM predictor values
#llm_predictor = LLMPredictor(llm=ChatOpenAI(temperature=0.1, model_name="gpt-3.5-turbo", max_tokens=num_outputs))

#Set service context
#service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
#loads openai key
def load_api_key(api_key):
    os.environ["OPENAI_API_KEY"] = str(api_key)
    openai.api_key = str(api_key)
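# Note: the key entered in the Gradio textbox is applied both to openai.api_key and to the
# OPENAI_API_KEY environment variable so that llama_index and langchain can pick it up.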
#identifies the current number of tokens used for the conversation
def num_tokens_from_messages(messages, model="gpt-3.5-turbo"):
    encoding = tiktoken.encoding_for_model(model)
    num_tokens = 0
    for message in messages:
        num_tokens += 4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
            if key == "name":  # if there's a name, the role is omitted
                num_tokens += -1  # role is always required and always 1 token
    num_tokens += 2  # every reply is primed with <im_start>assistant
    print(num_tokens)
    return num_tokens
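# The per-message overhead above (+4 per message, +2 for the assistant priming) follows OpenAI's
# published token-counting example for gpt-3.5-turbo; treat the result as an estimate, since newer
# model snapshots may use slightly different constants.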
#constructs the index and saves to a subfolder
def construct_index():
    print('Constructing index')
    # the module-level service context is commented out above, so build one here
    llm_predictor = LLMPredictor(llm=ChatOpenAI(temperature=0.1, model_name="gpt-3.5-turbo", max_tokens=num_outputs))
    service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
    # load in the documents from the docs subfolder
    docs = SimpleDirectoryReader('docs').load_data()
    index = VectorStoreIndex.from_documents(docs, service_context=service_context)
    index.set_index_id('vector_index')
    #Stores json files in a subfolder
    index.storage_context.persist(persist_dir=store)
    return index
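# Example (one-off rebuild): place the TFM source documents in the ./docs subfolder and call
#     construct_index()
# before launching the app, or re-enable the createIndex prompt near the bottom of this file.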
#resets the conversation
def generate_restart(prompt):
    messages.clear()
    messages.append({"role":"system", "content": "Tell the user that this conversation has been reset because the discussion reached its maximum size, and to please start by asking a new question."})
    storage_context = StorageContext.from_defaults(persist_dir=store)
    llm_predictor = LLMPredictor(llm=ChatOpenAI(temperature=0.1, model_name="gpt-3.5-turbo", max_tokens=num_outputs))
    service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
    #index = load_index_from_storage(storage_context)
    index = load_index_from_storage(
        StorageContext.from_defaults(persist_dir=store),
        service_context=service_context,
    )
    #query_engine = index.as_query_engine(text_qa_template=qa_template)
    chat_engine = index.as_chat_engine(text_qa_template=qa_template)
    string_message = str(messages)
    #response = query_engine.query(string_message)
    response = chat_engine.chat(string_message)  # chat() expects a string, so pass the stringified messages
    messages.clear()
    messages.append({"role":"system", "content": "This Chatbot is helpful, accurate, and provides all relevant information from the Treasury Financial Manual (TFM) when responding. This Chatbot always provides accurate citations from the TFM."})
    messages.append({"role":"user","content": ""})
    messages.append({"role":"assistant","content": ""})
    print("restart initiated")
    print(messages)
    return response.response
#generates the ChatGPT call
def generate_response(prompt):
    messages.append({"role": "user", "content": prompt})
    storage_context = StorageContext.from_defaults(persist_dir=store)
    llm_predictor = LLMPredictor(llm=ChatOpenAI(temperature=0.1, model_name="gpt-3.5-turbo", max_tokens=num_outputs))
    service_context = ServiceContext.from_defaults(llm=ChatOpenAI(temperature=0., model_name="gpt-3.5-turbo"))
    #service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
    #index = load_index_from_storage(storage_context)
    index = load_index_from_storage(
        StorageContext.from_defaults(persist_dir=store),
        service_context=service_context,
    )
    #query_engine = index.as_query_engine(text_qa_template=qa_template)
    #chat_engine = index.as_chat_engine(chat_mode='react', verbose=True, text_qa_template=qa_template)
    chat_engine = index.as_chat_engine(verbose=True, text_qa_template=qa_template)
    #chat_engine = index.SimpleChatEngine.from_defaults(service_context=service_context)
    #response = chat_engine.chat_repl(messages)
    string_message = str(messages)
    #response = query_engine.query(string_message)
    response = chat_engine.chat(prompt, chat_history)
    string_response = str(response)
    custom_chat_history.append("Human: " + prompt + " Assistant: " + string_response)
    #Optional[List[ChatMessage]] = None) → Union[Response, StreamingResponse]
    messages.append({"role": "assistant", "content": string_response})
    #messages.clear()
    #string_message=''
    num_tokens_from_messages(messages)
    print(messages)
    print("below is history")
    print(chat_history)
    print("below is type")
    print(type(custom_chat_history))
    print(type(chat_history))
    return response.response
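# Note: the module-level `messages` list is only used here for token accounting and logging;
# the chat engine itself receives `prompt` plus the separate `chat_history` list.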
def generate_response_backup(prompt):
    messages.append({"role": "user", "content": prompt})
    storage_context = StorageContext.from_defaults(persist_dir=store)
    llm_predictor = LLMPredictor(llm=ChatOpenAI(temperature=0.1, model_name="gpt-3.5-turbo", max_tokens=num_outputs))
    service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
    #index = load_index_from_storage(storage_context)
    index = load_index_from_storage(
        StorageContext.from_defaults(persist_dir=store),
        service_context=service_context,
    )
    query_engine = index.as_query_engine(text_qa_template=qa_template)
    string_message = str(messages)
    response = query_engine.query(string_message)
    string_response = str(response)
    messages.append({"role": "assistant", "content": string_response})
    string_message = ''
    num_tokens_from_messages(messages)
    print(messages)
    return response.response
#Called by the send button on the interface and tracks the history of the conversation
def my_chatbot_backup(input, history):
    history = history or []
    if num_tokens_from_messages(messages) < (int(token_limit) - int(max_response_tokens)):
        output = generate_response(input)
        history.append((input, output))
        return history, history
    else:
        history.clear()
        output = generate_restart(input)
        history.append((input, output))
        return history, history
def my_chatbot(input, history):
    history = history or []
    if num_tokens_from_messages(messages) < (int(token_limit) - int(max_response_tokens)):
        output = generate_response(input)
        history.append((input, output))
        return history, history
    else:
        history.clear()
        output = generate_restart(input)
        history.append((input, output))
        # return the (reset) history so the Chatbot component always receives a list of tuples
        return history, history
#alternative handler that queries the index directly, caching the query engine on first use
def index_chatbot(input_text):
    if not hasattr(index_chatbot, 'query_engine'):
        storage_context = StorageContext.from_defaults(persist_dir=store)
        index = load_index_from_storage(storage_context)
        index_chatbot.query_engine = index.as_query_engine(text_qa_template=qa_template)
    response = index_chatbot.query_engine.query(input_text)
    return response.response
with gr.Blocks() as demo:
    gr.Markdown("""<h1><center>TFM Chatbot 2.0</center></h1>""")
    api_key = gr.Textbox(type='password', label="Enter the API key")
    chatbot = gr.Chatbot()
    state = gr.State()
    text = gr.Textbox(placeholder="Hello. Ask me a question about the Treasury Financial Manual.")
    submit = gr.Button("SEND")
    submit.click(load_api_key, [api_key])
    submit.click(my_chatbot, inputs=[text, state], outputs=[chatbot, state])

#if createIndex in ["Y","y"]:
#    construct_index()

demo.launch(share=False)