demiroz commited on
Commit
77c8a38
1 Parent(s): c20a007

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -0
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A chatbot that uses the LangChain and Gradio UI to answer medical questions."""
2
+ import os
3
+ from types import SimpleNamespace
4
+
5
+ import gradio as gr
6
+ import wandb
7
+ from chain import get_answer, load_chain, load_vector_store
8
+ from config import default_config
9
+
10
+
11
+ class Chat:
12
+ """A chatbot interface that persists the vectorstore and chain between calls."""
13
+
14
+ def __init__(
15
+ self,
16
+ config: SimpleNamespace,
17
+ ):
18
+ """Initialize the chatbot.
19
+ Args:
20
+ config (SimpleNamespace): The configuration.
21
+ """
22
+ self.config = config
23
+ self.wandb_run = wandb.init(
24
+ project=self.config.project,
25
+ entity=self.config.entity,
26
+ job_type=self.config.job_type,
27
+ config=self.config,
28
+ )
29
+ self.vector_store = None
30
+ self.chain = None
31
+
32
+ def __call__(
33
+ self,
34
+ question: str,
35
+ history: list[tuple[str, str]] | None = None,
36
+ openai_api_key: str = None,
37
+ ):
38
+ """Answer a question about medical issues using the LangChain QA chain and vector store retriever.
39
+ Args:
40
+ question (str): The question to answer.
41
+ history (list[tuple[str, str]] | None, optional): The chat history. Defaults to None.
42
+ openai_api_key (str, optional): The OpenAI API key. Defaults to None.
43
+ Returns:
44
+ list[tuple[str, str]], list[tuple[str, str]]: The chat history before and after the question is answered.
45
+ """
46
+ if openai_api_key is not None:
47
+ openai_key = openai_api_key
48
+ elif os.environ["OPENAI_API_KEY"]:
49
+ openai_key = os.environ["OPENAI_API_KEY"]
50
+ else:
51
+ raise ValueError(
52
+ "Please provide your OpenAI API key as an argument or set the OPENAI_API_KEY environment variable"
53
+ )
54
+
55
+ if self.vector_store is None:
56
+ self.vector_store = load_vector_store(
57
+ wandb_run=self.wandb_run, openai_api_key=openai_key
58
+ )
59
+ if self.chain is None:
60
+ self.chain = load_chain(
61
+ self.wandb_run, self.vector_store, openai_api_key=openai_key
62
+ )
63
+
64
+ history = history or []
65
+ question = question.lower()
66
+ response = get_answer(
67
+ chain=self.chain,
68
+ question=question,
69
+ chat_history=history,
70
+ )
71
+ history.append((question, response))
72
+ return history, history
73
+
74
+
75
+ with gr.Blocks() as demo:
76
+ gr.HTML(
77
+ """<div style="text-align: center; max-width: 700px; margin: 0 auto;">
78
+ <div
79
+ style="
80
+ display: inline-flex;
81
+ align-items: center;
82
+ gap: 0.8rem;
83
+ font-size: 1.75rem;
84
+ "
85
+ >
86
+ <h1 style="font-weight: 900; margin-bottom: 7px; margin-top: 5px;">
87
+ Virtual Medical Assistant
88
+ </h1>
89
+ </div>
90
+ <p style="margin-bottom: 10px; font-size: 94%">
91
+ Hi, I'm a virtual medical assistant that can answer your medical questions. Please start by typing in your OpenAI API key and medical questions/issues.<br>
92
+ Built using <a href="https://langchain.readthedocs.io/en/latest/" target="_blank">LangChain</a> and <a href="https://github.com/gradio-app/gradio" target="_blank">Gradio Github repo</a>
93
+ </p>
94
+ </div>"""
95
+ )
96
+ with gr.Row():
97
+ question = gr.Textbox(
98
+ label="Type in your medical questions here and press Enter!",
99
+ placeholder="What is diabetes?",
100
+ )
101
+ openai_api_key = gr.Textbox(
102
+ type="password",
103
+ label="Enter your OpenAI API key here",
104
+ )
105
+ state = gr.State()
106
+ chatbot = gr.Chatbot()
107
+ question.submit(
108
+ Chat(
109
+ config=default_config,
110
+ ),
111
+ [question, state, openai_api_key],
112
+ [chatbot, state],
113
+ )
114
+
115
+
116
+ if __name__ == "__main__":
117
+ demo.queue().launch(
118
+ share=False, server_name="0.0.0.0", server_port=8884, show_error=True
119
+ )