Spaces:
Runtime error
Runtime error
import gradio as gr | |
import os | |
import openai | |
import requests | |
import csv | |
import faiss | |
import tiktoken | |
from langchain.llms import OpenAI | |
from langchain.chat_models import ChatOpenAI | |
from langchain.prompts import PromptTemplate | |
from langchain.document_loaders import PyPDFLoader | |
from langchain.vectorstores import FAISS | |
from langchain.embeddings import OpenAIEmbeddings | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain.embeddings import HuggingFaceHubEmbeddings | |
from langchain.agents.agent_toolkits import create_retriever_tool | |
from langchain.agents.agent_toolkits import create_conversational_retrieval_agent | |
from langchain.memory import ChatMessageHistory | |
from langchain.schema.messages import SystemMessage | |
from sentence_transformers import SentenceTransformer, util | |
# Name of the chat model handed to ChatOpenAI when the agent is built.
LLM_MODEL = "gpt-4-1106-preview"

# OpenAI API key, read from the OPEN_AI_KEY environment variable
# (None when the variable is not set).
OPEN_AI_KEY = os.getenv("OPEN_AI_KEY")
class AASLDConversationalAgent():
    """Conversational retrieval agent over the AASLD HCC practice-guidance PDF.

    On construction the PDF is loaded and split, embedded with
    ``OpenAIEmbeddings`` into a FAISS index, and exposed to a
    ``ChatOpenAI`` model through a single retriever tool. Questions are
    answered with :meth:`get_answer`.
    """

    # PDF that is indexed when no explicit path is supplied.
    DEFAULT_PDF_PATH = 'aasld_practice_guidance_on_prevention_diagnosis441.pdf'

    # Opening marker of a fenced HTML block in the model output, and the
    # generic fence that closes it.
    _HTML_FENCE_OPEN = '```html'
    _FENCE = '```'

    def __init__(self, pdf_path=None, model_name=None,
                 openai_api_key=None) -> None:
        """Build the vector index, the retriever tool and the chat agent.

        Args:
            pdf_path: Path of the PDF to index. Defaults to
                ``DEFAULT_PDF_PATH``.
            model_name: Chat model name. Defaults to the module-level
                ``LLM_MODEL``.
            openai_api_key: Key used for both the embeddings and the chat
                model. Defaults to the module-level ``OPEN_AI_KEY``.
        """
        # Defaults are resolved here rather than in the signature so the
        # module constants are only looked up when an agent is actually built.
        if pdf_path is None:
            pdf_path = self.DEFAULT_PDF_PATH
        if model_name is None:
            model_name = LLM_MODEL
        if openai_api_key is None:
            openai_api_key = OPEN_AI_KEY

        pages = PyPDFLoader(pdf_path).load_and_split()

        embeddings_model = OpenAIEmbeddings(openai_api_key=openai_api_key)
        # Alternative embedding backend that was left disabled in the source:
        # embeddings_model = HuggingFaceEmbeddings(
        #     model_name='sentence-transformers/all-mpnet-base-v2',
        #     model_kwargs={'device': 'cpu'},
        #     encode_kwargs={'normalize_embeddings': False}
        # )

        self.doc_search = FAISS.from_documents(pages, embeddings_model)
        # "mmr" search type, asking the retriever for 10 documents per query.
        self.doc_retriever = self.doc_search.as_retriever(
            search_type="mmr", search_kwargs={"k": 10})

        self.llm = ChatOpenAI(temperature=0, model_name=model_name,
                              openai_api_key=openai_api_key)

        tool = create_retriever_tool(
            self.doc_retriever,
            "SearchAASLDDocs",
            "Searches and returns documents regarding the AASLD Practice "
            "Guidance on\n"
            "Prevention, Diagnosis, and Treatment of Hepatocellular "
            "Carcinoma.",
        )
        self.tools = [tool]

        system_message = SystemMessage(
            content=(
                "\n"
                "You are a chat assistant who searches the document for the "
                "given\n"
                "question and returns the summary of the results as a HTML "
                "body with\n"
                "formatting. The results text should be left aligned.\n"
                "If the document does not have the answers please let the "
                "users know\n"
                "it is unavailable and do not try to answer.\n"
            )
        )

        self.rag_chat_agent = create_conversational_retrieval_agent(
            self.llm, self.tools, system_message=system_message,
            verbose=False,
            remember_intermediate_steps=False, max_token_limit=8000)

    def _get_html_content(self, res: str):
        """Return the contents of the first fenced HTML block in ``res``.

        Only the text between the opening HTML fence and the next fence is
        returned, so anything the model writes after the closing fence is
        dropped instead of being rendered. When ``res`` contains no HTML
        fence it is returned unchanged.

        Args:
            res: Raw model output.

        Returns:
            The extracted HTML fragment, or ``res`` itself if unfenced.
        """
        print(res)
        _, opener, remainder = res.partition(self._HTML_FENCE_OPEN)
        if not opener:
            # No fenced block at all: nothing to extract.
            return res
        # An unterminated fence yields everything after the opener.
        fragment, _, _ = remainder.partition(self._FENCE)
        return fragment

    def get_answer(self, question: str):
        """Ask the agent ``question`` and return its answer as HTML text.

        Args:
            question: The user's question.

        Returns:
            The agent's ``'output'`` value, unwrapped from its fenced HTML
            block when one is present.
        """
        output = self.rag_chat_agent(question)['output']
        if self._HTML_FENCE_OPEN in output:
            return self._get_html_content(output)
        return output
# Single shared agent instance, created when this module is executed and
# used by the question handler.
AASLD_CONVERSATIONAL_AGENT = AASLDConversationalAgent()

# Header markup shown at the top of the page.
PAGE_TITLE_HTML = (
    "\n"
    "<h1> Arithmedics - AASLD Guidelines Chat Assistant</h1>\n"
    "<h3> I am AASLD-GPT, trained on the AASLD guidelines for "
    "Hepatocellular Carcinoma (HCC). "
    "My role is to assist you in navigating and extracting information "
    "from these guidelines. "
    "As an early prototype, your patience and feedback are essential for "
    "my development and improvement. "
    "</h3>\n"
)
def get_empty_state():
    """Return a brand-new empty list to use as a blank state value."""
    blank_state = list()
    return blank_state
def clear_conversation():
    """Return the values that reset the conversation widgets.

    Returns:
        A 6-tuple: a ``gr.update`` that clears a component's value and makes
        it visible, ``None``, an empty string, a fresh empty state list, and
        two more empty strings.
    """
    # NOTE(review): the only wiring visible for this handler is commented out
    # and lists four outputs, while six values are returned here -- confirm
    # the output list before re-enabling it.
    reset_input = gr.update(value=None, visible=True)
    fresh_state = get_empty_state()
    return (reset_input, None, '', fresh_state, '', '')
def submit_question(question):
    """Send ``question`` to the shared agent and build the UI update values.

    Args:
        question: Text entered by the user.

    Returns:
        A 3-tuple ``(cleared_input, question_html, answer_html)``: an empty
        string to clear the textbox, the question HTML-escaped and wrapped in
        ``<b>...</b>``, and the agent's answer. If answering fails, the
        answer slot holds the HTML-escaped exception text instead.
    """
    # Function-scope import keeps this fix self-contained (stdlib only).
    import html

    print(question)
    try:
        answer = AASLD_CONVERSATIONAL_AGENT.get_answer(question)
        print(answer)
    except Exception as e:
        # UI boundary: show the failure to the user instead of crashing the
        # handler. Exceptions are not subscriptable, so use their text.
        print(e)
        answer = html.escape(str(e))
    # The question is user input rendered through gr.HTML, so escape it.
    question = '<b>{}</b>'.format(html.escape(str(question)))
    return ('', question, answer)
# Custom stylesheet passed to gr.Blocks; one rule per entry. The empty first
# and last entries reproduce the leading and trailing newline of the text.
css = "\n".join((
    "",
    "#col-container {max-width: 80%; margin-left: auto; margin-right: auto;}",
    "#arithmedics-img {width: 200px; display: block; margin: auto;}",
    "#submitbtn {width: 50%; margin:auto}",
    "#chatbox {min-height: 400px; .scroll-container {position: absolute; top: 0; right:0;}}",
    "#header {text-align: center;}",
    "#total_tokens_str {text-align: right; font-size: 0.8em; color: #666;}",
    "#label {font-size: 0.8em; padding: 0.5em; margin: 0;}",
    ".message { font-size: 1.2em; }",
    "",
))
# Page layout and event wiring.
# NOTE(review): indentation was not recoverable from the source text; the
# nesting below is reconstructed from the statement order -- confirm it
# against the deployed layout.
# Left disabled in the source and omitted here: the logo image
# ('Arithmedics_rectangle_logo.png'), the "Start New Conversation" button and
# its clear_conversation wiring, the messages State, the Chatbot, the token
# counter Markdown, and demo.queue(concurrency_count=10).
with gr.Blocks(title='AASLD Practice Guidelines Chat Assistant',
               css=css) as demo:
    with gr.Column(elem_id='col-container'):
        gr.HTML(PAGE_TITLE_HTML, elem_id='header')
        with gr.Row():
            with gr.Column(scale=7):
                input_message = gr.Textbox(
                    show_label=False,
                    placeholder='Enter your question',
                    visible=True,
                    container=False,
                )
                btn_submit = gr.Button(
                    'Ask AASLD Guidelines',
                    variant='primary',
                    elem_id='submitbtn',
                    size='sm',
                )
                question = gr.HTML()
                answer = gr.HTML()

    # Clicking the button and pressing Enter in the textbox run the same
    # handler: clear the textbox, then show the question and the answer.
    for register in (btn_submit.click, input_message.submit):
        register(
            submit_question,
            [input_message],
            [input_message, question, answer],
        )

if __name__ == "__main__":
    demo.launch(height='800px', share=True)