# This work © 2024 by Anthony Kelly is licensed under CC BY-NC-SA 4.0.
# To view a copy of this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To attribute, please cite:
# Kelly, A., Noctor, E., & Van de Ven, P. (2024, June 2-5).
# Design, architecture, and safety evaluation of an AI chatbot for an educational approach to health promotion in chronic medical conditions [Conference presentation].
# 12th Annual Meeting of the International Society for Research on Internet Interventions (ISRII), Limerick, Ireland.
import os
#os.system("pip uninstall -y gradio")
#os.system("pip install gradio==4.36.0")
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain.document_loaders import PyPDFLoader, Docx2txtLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.memory import ConversationSummaryMemory, ConversationBufferMemory
from langchain.prompts import (
    ChatPromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
import gradio as gr
import datetime
from huggingface_hub import Repository
from datasets import load_dataset
import random
import string
from pyairtable import Api
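# Required secrets: indexing os.environ here fails fast with a KeyError at startup if any is missing.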
os.environ["OPENAI_API_KEY"]
os.environ["HUB_TOKEN"]
os.environ["AIR_TOKEN"]
# Pull Lesson docs from dataset repo for privacy
repo = Repository(
    local_dir="private",
    repo_type="dataset",
    clone_from="https://huggingface.co/datasets/akellyirl/private_T2DM",
    token=os.environ["HUB_TOKEN"]
)
repo.git_pull()
# Scan the directories: a directory is considered valid if it contains a 'topic.txt' file
def find_and_read_topics(base_path):
    topics_list = []
    for dirpath, dirnames, filenames in os.walk(base_path):
        if "topic.txt" in filenames:
            with open(os.path.join(dirpath, "topic.txt"), "r", encoding="utf-8") as file:
                topic = file.read().strip()
            topics_list.append((dirpath, topic))
    return topics_list
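# Illustrative return shape (directory names and topic strings here are examples only):
#   [("./private/docs/lesson1", "Managing blood glucose"), ...]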
# Lesson docs pulled from repo for privacy
base_directory = "./private/docs"
topics = find_and_read_topics(base_directory)
for directory, topic_content in topics:
    print(f"Directory: {directory}\nTopic Content: {topic_content}\n")
# Select Topic
select = 0 # <=========
dir = topics[select][0]
topic = topics[select][1]
# Scan select directory for pdf files
files = []
for foldername, subfolders, filenames in os.walk(dir):
    for filename in filenames:
        if filename.endswith(('.pdf', '.PDF')):
            # Construct full file path and append to the files list
            path = os.path.join(foldername, filename)
            if os.path.isfile(path):
                files.append(path)
            else:
                print(f"{path} is not a valid path.")
print(f'{len(files)} files')
print(files)
# https://python.langchain.com/docs/use_cases/question_answering/how_to/chat_vector_db
# Create an instance of PyPDFLoader for each PDF file
loaders = [PyPDFLoader(file) for file in files]
# Load and split the PDFs into individual documents
data = []
for loader in loaders:
    data += loader.load_and_split()
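# Note: PyPDFLoader produces one Document per PDF page, and load_and_split() applies
# LangChain's default splitter; the results are re-chunked below into 500-character pieces.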
# SPLIT
from langchain.text_splitter import RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
all_splits = text_splitter.split_documents(data)
# STORE
vectorstore = Chroma.from_documents(documents=all_splits, embedding=OpenAIEmbeddings())
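# Without a persist_directory the Chroma index lives in memory, so embeddings are recomputed
# on every restart. Optional sketch (not used here) to persist the index to disk instead:
#   vectorstore = Chroma.from_documents(documents=all_splits,
#                                       embedding=OpenAIEmbeddings(),
#                                       persist_directory="./chroma_db")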
# Chat Model
model = 'gpt-3.5-turbo-0125'
#model = 'gpt-4'
llm = ChatOpenAI(model=model, temperature=0)
retriever = vectorstore.as_retriever()
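# as_retriever() defaults to similarity search; the number of chunks returned per query
# can be tuned if needed, e.g. vectorstore.as_retriever(search_kwargs={"k": 4}).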
def predict(message, history):
    system_template = r"""
    - You are a health education chatbot for type 2 diabetes (T2DM) patients.
    - You only discuss the documents provided and information related to them.
    - Your goal is to improve health management of chronic diseases.
    - You also need to help users explore their attitudes to their health in relation to T2DM.
    - Your answers should explain things clearly and avoid jargon.
    - You are allowed to chat with the user in general conversation to support your goal.
    - If the user goes off topic, gently and politely let them know and go back on topic.
    - You must be safe to use. If you don't know the answer then say that. Do not make anything up.
    - Always finish by establishing where you found the information in the documents provided, including the document, the section and page number.
    - If you did not find the information in the documents provided, then say "this information was sourced from my general knowledge."
    ----
    {context}
    ----
    """
    user_template = "Question:```{question}```"
    qa_prompt = ChatPromptTemplate.from_messages([SystemMessagePromptTemplate.from_template(system_template),
                                                  HumanMessagePromptTemplate.from_template(user_template)])
    qa = ConversationalRetrievalChain.from_llm(llm, retriever=retriever, combine_docs_chain_kwargs={"prompt": qa_prompt})
    chat_history = []
    for h1, h2 in history:
        chat_history.append((h1, h2))
    ans = qa({"question": message, "chat_history": chat_history})['answer']
    history.append((message, ans))
    # Clear the input box and return the updated conversation for the Chatbot component
    return "", history
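# Logging is minimal by design: only a session ID and timestamp are written to Airtable
# below; no message content is stored. AIR_TOKEN is assumed to have write access to the base.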
def add_log_entry(SessionID):
    '''Add an entry to Airtable'''
    # Connect to Airtable for logging
    base_id = 'appZ2hdiBsz6R76zm'
    table_name = 't2dm_log'
    api = Api(os.environ['AIR_TOKEN'])
    airtable = api.table(base_id, table_name)
    current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    record = {'session_id': SessionID, 'datetime': current_time}
    airtable.create(record)
    return None
def generate_session_id(length=10):
    # Generate a random alphanumeric string; user_id (defined below) is used as the prefix
    characters = string.ascii_letters + string.digits
    return user_id + '_T2DM_' + ''.join(random.choice(characters) for _ in range(length))
def initialize_id():
    return generate_session_id()
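# Because initialize_id is passed as a callable value to the Textbox below, Gradio calls it
# on each page load, so every visitor gets a fresh session ID.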
user_id = ""
with gr.Blocks(theme=gr.themes.Default()) as chat:
    # Generate a unique Session ID
    session_id = gr.Textbox(label="Session ID", value=initialize_id, interactive=False, visible=False)
    gr.Markdown(f"""# I am a customised AI chatbot for {topic}.
<i>Running {model}. NOTE: On rare occasions I can take up to 60 seconds to respond.
If I'm taking too long, please refresh the page and continue.</i>""")
    chatbot = gr.Chatbot(height=300, show_copy_button=False, show_share_button=False)
    with gr.Row():
        msg = gr.Textbox(placeholder="Type here >> ", container=False, scale=10, min_width=250)
        submit = gr.Button(value="Submit", variant="primary", scale=1, min_width=20)
    with gr.Row():
        report = gr.Button(value="REPORT", variant="secondary",
                           link="https://padlet.com/akellyirl/strathbot-flagging-2b4ko3rhk94wja6e")
        clear = gr.ClearButton([msg, chatbot])
    examples = ["What are the main topics?", "Explain this very simply",
                "Suggest a topic", "Tell me more about that",
                "Provide more reading"]
    def on_select(ex):
        return ex
    gr.Markdown("#### *Examples:*")
    ex = {}
    with gr.Group("Examples"):
        with gr.Row():
            for ind, exa in enumerate(examples):
                ex[ind] = gr.Textbox(exa, container=False, interactive=True)
                ex[ind].focus(fn=on_select, inputs=ex[ind], outputs=msg)
    # Submit on Enter or Button click; after each reply, log the session ID to Airtable
    gr.on(triggers=[msg.submit, submit.click],
          fn=predict, inputs=[msg, chatbot], outputs=[msg, chatbot],
          concurrency_limit=100).then(add_log_entry, session_id)
sessionID = generate_session_id()
chat.launch()