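# Gradio chat assistant that answers questions about the AASLD practice
# guidance on hepatocellular carcinoma (HCC) using a retrieval-augmented
# LangChain agent over a FAISS index of the guidance PDF.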
import gradio as gr
import os
import openai
import requests
import csv
import faiss
import tiktoken
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.document_loaders import PyPDFLoader
from langchain.vectorstores import FAISS
from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings import HuggingFaceHubEmbeddings
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain.agents.agent_toolkits import create_conversational_retrieval_agent
from langchain.memory import ChatMessageHistory
from langchain.schema.messages import SystemMessage
from sentence_transformers import SentenceTransformer, util
LLM_MODEL = 'gpt-4-1106-preview'
OPEN_AI_KEY = os.environ.get('OPEN_AI_KEY')
class AASLDConversationalAgent:
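    """Conversational retrieval (RAG) agent over the AASLD HCC practice guidance PDF."""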
def __init__(self) -> None:
# loader = PyPDFLoader(
# 'docs/aasld_practice_guidance_on_prevention_diagnosis441.pdf')
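        # Load the AASLD guidance PDF and split it into per-page documents.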
loader = PyPDFLoader(
'aasld_practice_guidance_on_prevention_diagnosis441.pdf')
pages = loader.load_and_split()
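        # Embed the pages with OpenAI embeddings; a local sentence-transformers
        # model is kept below as a commented-out alternative.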
embeddings_model = OpenAIEmbeddings(openai_api_key=OPEN_AI_KEY)
# embeddings_model = HuggingFaceEmbeddings(
# model_name='sentence-transformers/all-mpnet-base-v2',
# model_kwargs={'device': 'cpu'},
# encode_kwargs={'normalize_embeddings': False}
# )
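        # Index the embedded pages in FAISS and expose the index as an MMR
        # retriever that returns the top 10 chunks per query.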
self.doc_search = FAISS.from_documents(pages, embeddings_model)
self.doc_retriever = self.doc_search.as_retriever(
search_type="mmr", search_kwargs={"k": 10})
        self.llm = ChatOpenAI(temperature=0, model_name=LLM_MODEL,
                              openai_api_key=OPEN_AI_KEY)
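        # Wrap the retriever as a tool the conversational agent can call.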
tool = create_retriever_tool(
self.doc_retriever,
"SearchAASLDDocs",
"""Searches and returns documents regarding the AASLD Practice Guidance on
Prevention, Diagnosis, and Treatment of Hepatocellular Carcinoma.""",
)
self.tools = [tool]
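        # System prompt: answer from the document, format replies as HTML, and
        # admit when the answer is not in the document.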
        system_message = SystemMessage(
            content="""
            You are a chat assistant who searches the document for the given
            question and returns a summary of the results as an HTML body with
            formatting. The results text should be left aligned.
            If the document does not contain the answer, let the user know it
            is unavailable and do not try to answer.
            """
        )
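        # Assemble the conversational retrieval agent around the LLM and tool.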
self.rag_chat_agent = create_conversational_retrieval_agent(
self.llm, self.tools, system_message=system_message,
verbose=False,
remember_intermediate_steps=False, max_token_limit=8000)
    def _get_html_content(self, res: str):
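        """Extract the HTML body from a fenced ```html block in the agent output."""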
print(res)
return res.split('```html')[1].replace('```', '')
def get_answer(self, question: str):
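        """Ask the agent a question and return its answer as displayable HTML."""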
res = self.rag_chat_agent(question)
if '```html' in res['output']:
html_res = self._get_html_content(res['output'])
else:
html_res = res['output']
return html_res
AASLD_CONVERSATIONAL_AGENT = AASLDConversationalAgent()
PAGE_TITLE_HTML = """
<h1> Arithmedics - AASLD Guidelines Chat Assistant</h1>
<h3> I am AASLD-GPT, trained on the AASLD guidelines for Hepatocellular Carcinoma (HCC). My role is to assist you in navigating and extracting information from these guidelines. As an early prototype, your patience and feedback are essential for my development and improvement. </h3>
"""
def get_empty_state():
return []
def clear_conversation():
return (gr.update(value=None, visible=True), None, '', get_empty_state(),
'', '')
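# Handler for the submit button and textbox: echoes the question in bold and
# returns the agent's HTML answer (or the error message if the call fails).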
def submit_question(question):
print(question)
try:
answer = AASLD_CONVERSATIONAL_AGENT.get_answer(question)
print(answer)
    except Exception as e:
        print(e)
        # An Exception object is not subscriptable, so surface its string form.
        answer = str(e)
    question = '<b>{}</b>'.format(question)
return ('', question, answer)
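# Minimal CSS tweaks for the Gradio layout (centered column, button width,
# chat box height).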
css = """
#col-container {max-width: 80%; margin-left: auto; margin-right: auto;}
#arithmedics-img {width: 200px; display: block; margin: auto;}
#submitbtn {width: 50%; margin:auto}
#chatbox {min-height: 400px; .scroll-container {position: absolute; top: 0; right:0;}}
#header {text-align: center;}
#total_tokens_str {text-align: right; font-size: 0.8em; color: #666;}
#label {font-size: 0.8em; padding: 0.5em; margin: 0;}
.message { font-size: 1.2em; }
"""
with gr.Blocks(css=css,
               title='AASLD Practice Guidelines Chat Assistant') as demo:
# messages = gr.State(get_empty_state())
with gr.Column(elem_id='col-container'):
# gr.Image('Arithmedics_rectangle_logo.png',
# width=200,
# show_label=False,
# show_download_button=False,
# container=False,
# show_share_button=False,
# elem_id='arithmedics-img')
gr.HTML(PAGE_TITLE_HTML, elem_id='header')
with gr.Row():
with gr.Column(scale=7):
# btn_clear_conversation = gr.Button('πŸ”ƒ Start New Conversation')
input_message = gr.Textbox(
show_label=False,
placeholder='Enter your question',
visible=True, container=False)
btn_submit = gr.Button(
'Ask AASLD Guidelines',
variant='primary',
elem_id='submitbtn',
size='sm')
question = gr.HTML()
answer = gr.HTML()
# chatbot = gr.Chatbot(
# elem_id='chatbox', visible=False)
# total_tokens_str = gr.Markdown(elem_id='total_tokens_str')
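    # Both the submit button and pressing Enter in the textbox call the same handler.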
btn_submit.click(
submit_question,
[input_message],
[input_message, question, answer]
)
input_message.submit(
submit_question,
[input_message],
[input_message, question, answer]
)
# btn_clear_conversation.click(
# clear_conversation, [],
# [input_message, messages, question, answer])
# demo.queue(concurrency_count=10)
if __name__ == "__main__":
demo.launch(height='800px', share=True)