File size: 6,387 Bytes
3c9ee2e
569e93f
 
 
 
 
 
 
 
 
 
 
 
05a8056
fe648aa
569e93f
 
 
 
cf58c38
 
3c9ee2e
493538b
569e93f
3c9ee2e
569e93f
 
 
 
 
c4e8ff0
569e93f
75d1596
 
 
 
 
 
569e93f
b462580
7906531
75d1596
 
569e93f
 
 
 
 
 
 
 
 
 
 
c5960fe
 
 
569e93f
 
 
 
 
 
 
 
 
f423da9
569e93f
 
 
 
 
 
 
 
 
 
 
 
6dcbc0e
c5960fe
569e93f
 
 
4aeb33c
569e93f
 
 
 
 
f4c0fc7
fe648aa
a2fd386
 
 
 
 
11495a7
5e548ee
f4c0fc7
569e93f
 
 
22cf28c
f9437a5
569e93f
 
 
 
 
 
 
4f60678
c5960fe
f4c0fc7
7016312
06d8a15
 
 
 
 
 
 
7016312
 
 
96660e9
7016312
 
 
 
54f0768
 
 
 
 
5e548ee
96660e9
 
 
 
569e93f
 
 
f4c0fc7
 
 
 
 
 
 
 
96660e9
 
 
ef5cdc9
b392e35
96660e9
ed21381
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import csv
import html
import os

import gradio as gr
import openai
import requests
import faiss
import tiktoken
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.document_loaders import PyPDFLoader
from langchain.vectorstores import FAISS
from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings import HuggingFaceHubEmbeddings
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain.agents.agent_toolkits import create_conversational_retrieval_agent
from langchain.memory import ChatMessageHistory
from langchain.schema.messages import SystemMessage
from sentence_transformers import SentenceTransformer, util


# Chat model name handed to ChatOpenAI by the retrieval agent.
LLM_MODEL = 'gpt-4-1106-preview'
# API key taken from the OPEN_AI_KEY environment variable (None when unset).
OPEN_AI_KEY = os.getenv('OPEN_AI_KEY')

class AASLDConversationalAgent():
    """Conversational RAG agent over the AASLD HCC practice-guidance PDF.

    Construction does all the heavy lifting:

    * the PDF is loaded and split into pages;
    * the pages are embedded with ``OpenAIEmbeddings`` into an in-memory
      FAISS index;
    * the index is exposed as a retriever tool to a LangChain conversational
      retrieval agent driven by ``ChatOpenAI`` (model ``LLM_MODEL``).
    """

    def __init__(
        self,
        pdf_path: str = 'aasld_practice_guidance_on_prevention_diagnosis441.pdf',
    ) -> None:
        """Index the guidance PDF and build the retrieval agent.

        Args:
            pdf_path: Path to the guidance PDF. Defaults to the file name the
                app has always used; a relative path is resolved against the
                current working directory.
        """
        loader = PyPDFLoader(pdf_path)
        pages = loader.load_and_split()
        embeddings_model = OpenAIEmbeddings(openai_api_key=OPEN_AI_KEY)
        # Local alternative that avoids OpenAI embedding calls:
        # embeddings_model = HuggingFaceEmbeddings(
        #     model_name='sentence-transformers/all-mpnet-base-v2',
        #     model_kwargs={'device': 'cpu'},
        #     encode_kwargs={'normalize_embeddings': False}
        #     )
        self.doc_search = FAISS.from_documents(pages, embeddings_model)
        # Maximal-marginal-relevance search: k=10 pages that are relevant to
        # the query but not redundant with each other.
        self.doc_retriever = self.doc_search.as_retriever(
            search_type="mmr", search_kwargs={"k": 10})
        self.llm = ChatOpenAI(temperature=0, model_name=LLM_MODEL,
                              openai_api_key=OPEN_AI_KEY)
        tool = create_retriever_tool(
            self.doc_retriever,
            "SearchAASLDDocs",
            """Searches and returns documents regarding the AASLD Practice Guidance on
            Prevention, Diagnosis, and Treatment of Hepatocellular Carcinoma.""",
        )
        self.tools = [tool]
        # The prompt asks the model to answer as HTML; get_answer() unwraps
        # the ```html fence the model typically puts around it.
        system_message = SystemMessage(
            content="""
            You are a chat assistant who searches the document for the given 
            question and returns the summary of the results as a HTML body with
            formatting. The results text should be left aligned. 
            If the document does not have the answers please let the users know 
            it is unavailable and do not try to answer.
            """
        )
        # NOTE(review): max_token_limit presumably bounds the size of the
        # remembered chat history kept by the agent's memory -- confirm
        # against the LangChain version in use.
        self.rag_chat_agent = create_conversational_retrieval_agent(
            self.llm, self.tools, system_message=system_message,
            verbose=False,
            remember_intermediate_steps=False, max_token_limit=8000)

    def _get_html_content(self, res: str) -> str:
        """Return the body of the first ```html fenced block in *res*.

        Only the text between the opening ```html fence and the next ```
        fence is returned. Prose the model writes before or after the block
        is therefore not rendered. If the closing fence is missing,
        everything after the opening fence is returned. If there is no
        opening fence at all, the result is an empty string.

        Args:
            res: Raw model output.

        Returns:
            The HTML inside the fence, whitespace left untouched.
        """
        _, _, after_open = res.partition('```html')
        body, _, _ = after_open.partition('```')
        return body

    def get_answer(self, question: str) -> str:
        """Ask the agent *question* and return its answer as HTML text.

        Args:
            question: The user's question, passed unchanged to the agent.

        Returns:
            The contents of the ```html block when the agent's output contains
            one; otherwise the agent's raw ``output`` string.
        """
        res = self.rag_chat_agent(question)
        output = res['output']
        if '```html' in output:
            return self._get_html_content(output)
        return output

# Single shared agent, built at import time. Constructing it loads the PDF
# and builds the FAISS index with OpenAI embeddings, so importing this module
# performs that work up front. Every UI request reuses this instance.
AASLD_CONVERSATIONAL_AGENT = AASLDConversationalAgent()

# Header markup shown at the top of the page (rendered via gr.HTML).
PAGE_TITLE_HTML = """
<h1> Arithmedics - AASLD Guidelines Chat Assistant</h1> 
<h3> I am AASLD-GPT, trained on the AASLD guidelines for Hepatocellular Carcinoma (HCC). My role is to assist you in navigating and extracting information from these guidelines. As an early prototype, your patience and feedback are essential for my development and improvement. </h3>
"""

def get_empty_state():
    """Return a brand-new, empty conversation-state list.

    A new list is created on every call, so callers never share state.
    """
    fresh_state = list()
    return fresh_state

def clear_conversation():
    """Produce reset values for UI components when starting over.

    Returns:
        A 6-tuple, in this order:

        1. an update that blanks and shows the input box;
        2. ``None``;
        3. ``''``;
        4. a fresh empty state list from :func:`get_empty_state`;
        5. ``''``;
        6. ``''``.

    NOTE(review): this handler is not connected to any event in this file.
    If it is wired up, the outputs list must contain six components in the
    matching order.
    """
    cleared_input = gr.update(value=None, visible=True)
    return cleared_input, None, '', get_empty_state(), '', ''

def submit_question(question):
    """Answer *question* with the AASLD agent and format it for display.

    Args:
        question: Raw text typed by the user in the input box.

    Returns:
        A 3-tuple for the Gradio outputs ``[input_message, question, answer]``:

        * ``''``, which clears the input box;
        * the question wrapped in ``<b>...</b>``, HTML-escaped;
        * the agent's HTML answer, or the HTML-escaped exception text if the
          agent call fails.
    """
    print(question)
    try:
        answer = AASLD_CONVERSATIONAL_AGENT.get_answer(question)
    except Exception as e:  # UI boundary: show the error rather than crash
        print(e)
        # Exception objects are not subscriptable, so surface their text.
        # Escape it because the answer pane renders HTML.
        answer = html.escape(str(e))
    else:
        print(answer)
    # The question pane is a gr.HTML component: escape the user's text so it
    # cannot inject markup.
    question_html = '<b>{}</b>'.format(html.escape(question))
    return ('', question_html, answer)

# Custom stylesheet passed to gr.Blocks(css=...). The #col-container,
# #submitbtn and #header rules target elem_id values set on components in
# the layout below.
# NOTE(review): #arithmedics-img, #chatbox, #total_tokens_str and #label do
# not match the elem_id of any active component -- presumably leftovers from
# earlier UI iterations; confirm before removing.
# The #chatbox rule nests .scroll-container using native CSS nesting, which
# only takes effect in browsers that support CSS nesting.
css = """
      #col-container {max-width: 80%; margin-left: auto; margin-right: auto;}
      #arithmedics-img {width: 200px; display: block; margin: auto;}
      #submitbtn {width: 50%; margin:auto}
      #chatbox {min-height: 400px; .scroll-container {position: absolute; top: 0; right:0;}}
      #header {text-align: center;}
      #total_tokens_str {text-align: right; font-size: 0.8em; color: #666;}
      #label {font-size: 0.8em; padding: 0.5em; margin: 0;}
      .message { font-size: 1.2em; }
      """

# Gradio UI: a question box with a submit button, followed by two HTML panes
# that echo the submitted question and show the agent's answer.
with gr.Blocks(
    css=css,
    title='AASLD Practice Guidelines Chat Assistant',
) as demo:
    with gr.Column(elem_id='col-container'):
        gr.HTML(PAGE_TITLE_HTML, elem_id='header')
        with gr.Row():
            with gr.Column(scale=7):
                input_message = gr.Textbox(
                    placeholder='Enter your question',
                    show_label=False,
                    container=False,
                    visible=True,
                )
                btn_submit = gr.Button(
                    'Ask AASLD Guidelines',
                    elem_id='submitbtn',
                    variant='primary',
                    size='sm',
                )
                question = gr.HTML()
                answer = gr.HTML()

    # Clicking the button and pressing Enter in the textbox behave the same:
    # clear the box, then render the question and the agent's answer.
    for _register_submit in (btn_submit.click, input_message.submit):
        _register_submit(
            submit_question,
            [input_message],
            [input_message, question, answer],
        )

if __name__ == "__main__":
    demo.launch(height='800px', share=True)