Spaces:
Runtime error
Runtime error
File size: 6,387 Bytes
3c9ee2e 569e93f 05a8056 fe648aa 569e93f cf58c38 3c9ee2e 493538b 569e93f 3c9ee2e 569e93f c4e8ff0 569e93f 75d1596 569e93f b462580 7906531 75d1596 569e93f c5960fe 569e93f f423da9 569e93f 6dcbc0e c5960fe 569e93f 4aeb33c 569e93f f4c0fc7 fe648aa a2fd386 11495a7 5e548ee f4c0fc7 569e93f 22cf28c f9437a5 569e93f 4f60678 c5960fe f4c0fc7 7016312 06d8a15 7016312 96660e9 7016312 54f0768 5e548ee 96660e9 569e93f f4c0fc7 96660e9 ef5cdc9 b392e35 96660e9 ed21381 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import gradio as gr
import os
import openai
import requests
import csv
import faiss
import tiktoken
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.document_loaders import PyPDFLoader
from langchain.vectorstores import FAISS
from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings import HuggingFaceHubEmbeddings
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain.agents.agent_toolkits import create_conversational_retrieval_agent
from langchain.memory import ChatMessageHistory
from langchain.schema.messages import SystemMessage
from sentence_transformers import SentenceTransformer, util
# Chat model used by the conversational retrieval agent.
LLM_MODEL = 'gpt-4-1106-preview'
# NOTE(review): env var is named 'OPEN_AI_KEY' (not the conventional
# OPENAI_API_KEY) — confirm this exact name is set in the deployment env;
# os.environ.get returns None if missing, which surfaces later at API call time.
OPEN_AI_KEY = os.environ.get('OPEN_AI_KEY')
class AASLDConversationalAgent():
    """Conversational RAG agent over the AASLD HCC practice-guidance PDF.

    On construction: loads the guidance PDF, splits it into pages, embeds
    them with OpenAI embeddings into a FAISS index, and wires a LangChain
    conversational retrieval agent around a retriever tool and a chat model.
    """

    def __init__(self) -> None:
        # loader = PyPDFLoader(
        #     'docs/aasld_practice_guidance_on_prevention_diagnosis441.pdf')
        loader = PyPDFLoader(
            'aasld_practice_guidance_on_prevention_diagnosis441.pdf')
        pages = loader.load_and_split()
        embeddings_model = OpenAIEmbeddings(openai_api_key=OPEN_AI_KEY)
        # embeddings_model = HuggingFaceEmbeddings(
        #     model_name='sentence-transformers/all-mpnet-base-v2',
        #     model_kwargs={'device': 'cpu'},
        #     encode_kwargs={'normalize_embeddings': False}
        # )
        self.doc_search = FAISS.from_documents(pages, embeddings_model)
        # MMR keeps the k=10 retrieved chunks diverse instead of near-duplicates.
        self.doc_retriever = self.doc_search.as_retriever(
            search_type="mmr", search_kwargs={"k": 10})
        # temperature=0 for deterministic, extractive-style answers.
        self.llm = ChatOpenAI(temperature=0, model_name=LLM_MODEL,
                              openai_api_key=OPEN_AI_KEY)
        tool = create_retriever_tool(
            self.doc_retriever,
            "SearchAASLDDocs",
            """Searches and returns documents regarding the AASLD Practice Guidance on
            Prevention, Diagnosis, and Treatment of Hepatocellular Carcinoma.""",
        )
        self.tools = [tool]
        system_message = SystemMessage(
            content="""
            You are a chat assistant who searches the document for the given
            question and returns the summary of the results as a HTML body with
            formatting. The results text should be left aligned.
            If the document does not have the answers please let the users know
            it is unavailable and do not try to answer.
            """
        )
        self.rag_chat_agent = create_conversational_retrieval_agent(
            self.llm, self.tools, system_message=system_message,
            verbose=False,
            remember_intermediate_steps=False, max_token_limit=8000)

    def _get_html_content(self, res: str) -> str:
        """Extract the HTML body from a ```html fenced block in *res*.

        Fix: the original unconditionally indexed split(...)[1] and raised
        IndexError when no fence was present; now the text is returned
        unchanged in that case. Fenced-path behavior is preserved: take the
        text after the first ```html marker and strip remaining ``` fences.
        """
        print(res)  # deliberate stdout logging for the Space's logs
        parts = res.split('```html')
        if len(parts) < 2:
            return res
        return parts[1].replace('```', '')

    def get_answer(self, question: str) -> str:
        """Run *question* through the retrieval agent and return displayable text.

        If the agent wrapped its answer in a ```html fence, the fence is
        stripped; otherwise the raw output is returned.
        """
        res = self.rag_chat_agent(question)
        output = res['output']
        if '```html' in output:
            return self._get_html_content(output)
        return output
# Module-level singleton, built at import time. NOTE(review): construction
# loads the PDF and calls the OpenAI embeddings API, so app startup is slow
# and fails outright if the key/file is unavailable — confirm that is intended.
AASLD_CONVERSATIONAL_AGENT = AASLDConversationalAgent()
# Static HTML header rendered at the top of the Gradio page.
PAGE_TITLE_HTML = """
<h1> Arithmedics - AASLD Guidelines Chat Assistant</h1>
<h3> I am AASLD-GPT, trained on the AASLD guidelines for Hepatocellular Carcinoma (HCC). My role is to assist you in navigating and extracting information from these guidelines. As an early prototype, your patience and feedback are essential for my development and improvement. </h3>
"""
def get_empty_state():
    """Return a fresh, empty conversation-state list."""
    empty_state = list()
    return empty_state
def clear_conversation():
    """Reset every bound UI component to its initial empty value.

    Returns the 6-tuple Gradio expects: a cleared-and-visible input box
    update, None, empty question HTML, a fresh state list, and two more
    empty strings.
    """
    cleared_input = gr.update(value=None, visible=True)
    return (cleared_input, None, '', get_empty_state(), '', '')
def submit_question(question):
    """Answer *question* with the global agent, formatted for the UI.

    Returns a 3-tuple for Gradio outputs: ('' to clear the input box,
    the question wrapped in <b> tags, the answer text). Any failure is
    caught and its message shown as the answer instead of crashing the UI.
    """
    print(question)  # deliberate stdout logging for the Space's logs
    try:
        answer = AASLD_CONVERSATIONAL_AGENT.get_answer(question)
        print(answer)
    except Exception as e:
        print(e)
        # Fix: exceptions are not subscriptable — e['error']['message']
        # itself raised TypeError inside the handler. Show str(e) instead.
        answer = str(e)
    # Fix: closing tag was '</b' (missing '>'), producing broken HTML.
    question = '<b>{}</b>'.format(question)
    return ('', question, answer)
# Page-level CSS injected into the Gradio Blocks layout (see gr.Blocks(css=...)).
css = """
#col-container {max-width: 80%; margin-left: auto; margin-right: auto;}
#arithmedics-img {width: 200px; display: block; margin: auto;}
#submitbtn {width: 50%; margin:auto}
#chatbox {min-height: 400px; .scroll-container {position: absolute; top: 0; right:0;}}
#header {text-align: center;}
#total_tokens_str {text-align: right; font-size: 0.8em; color: #666;}
#label {font-size: 0.8em; padding: 0.5em; margin: 0;}
.message { font-size: 1.2em; }
"""
# --- Gradio UI wiring (flat script; builds the page at import time) ---
with gr.Blocks(css=css, title='AASLD Practice Guidelines Chat Assistant',
               ) as demo:
    # messages = gr.State(get_empty_state())
    with gr.Column(elem_id='col-container'):
        # gr.Image('Arithmedics_rectangle_logo.png',
        #          width=200,
        #          show_label=False,
        #          show_download_button=False,
        #          container=False,
        #          show_share_button=False,
        #          elem_id='arithmedics-img')
        gr.HTML(PAGE_TITLE_HTML, elem_id='header')
        with gr.Row():
            with gr.Column(scale=7):
                # btn_clear_conversation = gr.Button('🔃 Start New Conversation')
                # Free-text question box; cleared after each submit.
                input_message = gr.Textbox(
                    show_label=False,
                    placeholder='Enter your question',
                    visible=True, container=False)
                btn_submit = gr.Button(
                    'Ask AASLD Guidelines',
                    variant='primary',
                    elem_id='submitbtn',
                    size='sm')
                # Last question (bolded) and last answer, rendered as HTML.
                question = gr.HTML()
                answer = gr.HTML()
        # chatbot = gr.Chatbot(
        #     elem_id='chatbox', visible=False)
        # total_tokens_str = gr.Markdown(elem_id='total_tokens_str')
    # Both the button click and pressing Enter in the textbox route through
    # submit_question: input is the textbox, outputs clear the textbox and
    # fill the question/answer HTML panes.
    btn_submit.click(
        submit_question,
        [input_message],
        [input_message, question, answer]
    )
    input_message.submit(
        submit_question,
        [input_message],
        [input_message, question, answer]
    )
    # btn_clear_conversation.click(
    #     clear_conversation, [],
    #     [input_message, messages, question, answer])
# demo.queue(concurrency_count=10)
if __name__ == "__main__":
    # share=True exposes a public gradio.live URL in addition to localhost.
    demo.launch(height='800px', share=True)