YoniFriedman commited on
Commit
e2ed0fc
·
verified ·
1 Parent(s): 40b509d

initial app commit

Browse files
Files changed (1) hide show
  1. app.py +128 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["OPENAI_API_KEY"] = "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"
3
+
4
+ from llama_index.llms.openai import OpenAI
5
+ from llama_index.core.schema import MetadataMode
6
+ import openai
7
+ from openai import OpenAI as OpenAIOG
8
+ import logging
9
+ import sys
10
+ llm = OpenAI(temperature=0.0, model="gpt-4-turbo")
11
+ client = OpenAIOG()
12
+
13
+ from langdetect import detect
14
+ from langdetect import DetectorFactory
15
+ DetectorFactory.seed = 0
16
+ from deep_translator import GoogleTranslator
17
+
18
+ # Load index
19
+ from llama_index.core import VectorStoreIndex
20
+ from llama_index.core import StorageContext
21
+ from llama_index.core import load_index_from_storage
22
+ storage_context = StorageContext.from_defaults(persist_dir="arv_metadata")
23
+ index = load_index_from_storage(storage_context)
24
+ query_engine = index.as_query_engine(similarity_top_k=3, llm=llm)
25
+ retriever = index.as_retriever(similarity_top_k=3)
26
+
27
+ import gradio as gr
28
+
29
+ def nishauri(question: str, ccc_user: str, conversation_history: list[str]):
30
+
31
+ context = " ".join([item["user"] + " " + item["chatbot"] for item in conversation_history])
32
+
33
+ # Get patient info from DB
34
+ engine = create_engine('sqlite:///nishauri.db')
35
+
36
+ with engine.connect() as connection:
37
+ # Select data using a parameterized query
38
+ result = connection.execute(
39
+ text("SELECT visit_date, visit_type, regimen, viral_load FROM nishauri WHERE ccc_no = :ccc_no"),
40
+ {"ccc_no": ccc_user}
41
+ )
42
+
43
+ # Fetch and print results
44
+ row = result.fetchall()
45
+
46
+ last_appt = row[0][0]
47
+ appt_purpose = row[0][1]
48
+ regimen = row[0][2]
49
+ vl_result = row[0][3]
50
+
51
+
52
+ # Detect language of question - if Swahili, translate to English
53
+ # only do this if there are at least 5 words in the text, otherwise lang detection is unreliable
54
+
55
+ # Split the string into words
56
+ words = question.split()
57
+
58
+ # Count the number of words
59
+ num_words = len(words)
60
+
61
+ lang_question = "en"
62
+
63
+ if num_words > 4:
64
+ lang_question = detect(question)
65
+
66
+ # lang_question = detect(question)
67
+
68
+ if lang_question=="sw":
69
+ question = GoogleTranslator(source='sw', target='en').translate(question)
70
+
71
+ sources = retriever.retrieve(question)
72
+ source0 = sources[0].text
73
+ source1 = sources[1].text
74
+
75
+ background = ("The person who asked the question is a person living with HIV."
76
+ " If the person says sasa or niaje, that is swahili slang for hello. Just say hello back and ask how you can help."
77
+ " Recognize that they already have HIV and do not suggest that they have to get tested"
78
+ " for HIV or take post-exposure prophylaxis, as that is not relevant, though their partners perhaps should."
79
+ " Do not suggest anything that is not relevant to someone who already has HIV."
80
+ " Do not mention in the response that the person is living with HIV."
81
+ f" The person's last appointment was on {last_appt} and the purpose was {appt_purpose}. "
82
+ f" The person is on the following regimen for HIV: {regimen}. "
83
+ f" The person's most recent viral load result was {vl_result}. "
84
+ " The following information about viral loads is authoritative for any question about viral loads:"
85
+ " Under 50 copies/ml is low detectable level,"
86
+ " 50 - 199 copies/ml is low level viremia, 200 - 999 is high level viremia, and "
87
+ " 1000 and above is suspected treatment failure."
88
+ " A high viral load or non-suppressed viral load is any viral load above 200 copies/ml."
89
+ " A suppressed viral load is one below 200 copies / ml.")
90
+
91
+ question_final = (
92
+ f" The user previously asked and answered the following: {context}. "
93
+ f" The user just asked the following question: {question}."
94
+ f" Please use the following content to generate a response: {source0} {source1}."
95
+ f" The following background on the user should also inform the response as needed: {background}"
96
+ " Keep answers brief and limited to the question that was asked."
97
+ " Do not provide information the user did not ask about. If they start with a greeting, just greet them in return and don't share anything else."
98
+ )
99
+
100
+ completion = client.chat.completions.create(
101
+ model="gpt-4-turbo",
102
+ messages=[
103
+ {"role": "user", "content": question_final}
104
+ ]
105
+ )
106
+
107
+ reply_to_user = completion.choices[0].message.content
108
+
109
+
110
+ # If initial question was in Swahili, translate response back to Swahili
111
+ if lang_question=="sw":
112
+ reply_to_user = GoogleTranslator(source='auto', target='sw').translate(reply_to_user)
113
+
114
+ conversation_history.append({"user": question, "chatbot": reply_to_user})
115
+
116
+ return reply_to_user, conversation_history
117
+
118
+
119
+ demo = gr.Interface(
120
+ title = "Nishauri Chatbot Demo",
121
+ fn=nishauri,
122
+ inputs=[gr.Textbox(label="question", placeholder="Type your question here..."),
123
+ gr.Textbox(label="CCC", placeholder="Type your ccc here..."),
124
+ gr.State(value = [])],
125
+ outputs=["text", gr.State()],
126
+ )
127
+
128
+ demo.launch()