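# Retrieval-augmented chatbot for type 2 diabetes (T2DM) patient education.
# Pipeline: clone a private Hugging Face dataset of source documents, index the
# PDFs for a selected topic into a Chroma vector store, then serve a Gradio chat
# UI backed by a LangChain ConversationalRetrievalChain. Sessions are logged to
# Airtable.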
import os
import datetime
import random
import string

from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.prompts import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)

import gradio as gr
from huggingface_hub import Repository
from pyairtable import Api
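
# Fail fast if any required secret is missing: a bare os.environ[...] lookup
# raises KeyError at startup rather than failing mid-session.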
os.environ["OPENAI_API_KEY"] |
|
os.environ["HUB_TOKEN"] |
|
os.environ["AIR_TOKEN"] |
|
|
|
|
|
|
|
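
# Clone (or update) the private Hugging Face dataset holding the source
# documents. Repository is the git-based API of huggingface_hub.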
repo = Repository(
    local_dir="private",
    repo_type="dataset",
    clone_from="https://huggingface.co/datasets/akellyirl/private_T2DM",
    token=os.environ["HUB_TOKEN"],
)
repo.git_pull()
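
# find_and_read_topics returns (directory, topic description) pairs for every
# folder under base_path that contains a topic.txt file.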
def find_and_read_topics(base_path):
    topics_list = []

    for dirpath, dirnames, filenames in os.walk(base_path):
        if "topic.txt" in filenames:
            with open(os.path.join(dirpath, "topic.txt"), "r", encoding="utf-8") as file:
                topic = file.read().strip()
                topics_list.append((dirpath, topic))

    return topics_list
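
# Discover the available topics and print them for inspection at startup.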
base_directory = "./private/docs"
topics = find_and_read_topics(base_directory)
for directory, topic_content in topics:
    print(f"Directory: {directory}\nTopic Content: {topic_content}\n")
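
# Choose which topic to serve (index into the discovered topics list).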
select = 0

doc_dir = topics[select][0]
topic = topics[select][1]
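
# Collect every PDF under the selected topic directory.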
files = []
for foldername, subfolders, filenames in os.walk(doc_dir):
    for filename in filenames:
        if filename.lower().endswith('.pdf'):
            path = os.path.join(foldername, filename)
            if os.path.isfile(path):
                files.append(path)
            else:
                print(f"{path} is not a valid path.")

print(f'{len(files)} files')
print(files)
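
# Load every PDF and split it into Documents (PyPDFLoader yields one Document
# per page; load_and_split applies a default character splitter on top), then
# re-chunk everything into 500-character, non-overlapping pieces for retrieval.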
loaders = [PyPDFLoader(file) for file in files]

data = []
for loader in loaders:
    data += loader.load_and_split()

text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
all_splits = text_splitter.split_documents(data)
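
# Embed the chunks with OpenAI embeddings into an in-memory Chroma store
# (no persist_directory, so the index is rebuilt on every launch).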
vectorstore = Chroma.from_documents(documents=all_splits, embedding=OpenAIEmbeddings())

model = 'gpt-3.5-turbo-0125'

llm = ChatOpenAI(model=model, temperature=0)

retriever = vectorstore.as_retriever()
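
# Gradio callback: rebuilds the retrieval chain on every turn, replays the chat
# history, and returns ("", updated_history) so the input box clears.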
def predict(message, history):

    system_template = r"""
    - You are a health education chatbot for type 2 diabetes (T2DM) patients.
    - You only discuss the documents provided and information related to them.
    - Your goal is to improve health management of chronic diseases.
    - You also need to help users explore their attitudes to their health in relation to T2DM.
    - Your answers should explain things clearly and avoid jargon.
    - You are allowed to chat with the user in general conversation to support your goal.
    - If the user goes off topic, gently and politely let them know and go back on topic.
    - You must be safe to use. If you don't know the answer then say that. Do not make anything up.
    - Always finish by stating where you found the information in the documents provided, including the document, the section and the page number.
    - If you did not find the information in the documents provided, then say "this information was sourced from my general knowledge."
    ----
    {context}
    ----
    """

    user_template = "Question:```{question}```"

    qa_prompt = ChatPromptTemplate.from_messages([
        SystemMessagePromptTemplate.from_template(system_template),
        HumanMessagePromptTemplate.from_template(user_template),
    ])

    qa = ConversationalRetrievalChain.from_llm(
        llm, retriever=retriever, combine_docs_chain_kwargs={"prompt": qa_prompt}
    )

    chat_history = [(h1, h2) for h1, h2 in history]

    ans = qa({"question": message, "chat_history": chat_history})['answer']

    history.append((message, ans))

    return "", history
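
# Log each exchange to Airtable: one record per call with the session ID and a
# timestamp (the message content itself is not stored).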
def add_log_entry(SessionID):
    '''Add an entry to Airtable'''

    base_id = 'appZ2hdiBsz6R76zm'
    table_name = 't2dm_log'
    api = Api(os.environ['AIR_TOKEN'])
    airtable = api.table(base_id, table_name)

    current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    record = {'session_id': SessionID, 'datetime': current_time}
    airtable.create(record)
    return None
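
# Session IDs combine the (currently empty) module-level user_id with a random
# alphanumeric suffix; initialize_id defers generation until the UI renders.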
user_id = ""


def generate_session_id(length=10):
    characters = string.ascii_letters + string.digits
    return user_id + '_T2DM_' + ''.join(random.choice(characters) for _ in range(length))


def initialize_id():
    return generate_session_id()
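
# Build the Gradio UI: hidden session ID, header, chat window, input row, and
# report/clear controls.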
with gr.Blocks(theme=gr.themes.Default()) as chat:

    session_id = gr.Textbox(label="Session ID", value=initialize_id, interactive=False, visible=False)

    gr.Markdown(f"""# I am a customised AI chatbot for {topic}.
    <i>Running {model}. NOTE: On rare occasions I can take up to 60 seconds to respond.
    If I'm taking too long, please refresh the page and continue.</i>""")

    chatbot = gr.Chatbot(height=300, show_copy_button=False, show_share_button=False)

    with gr.Row():
        msg = gr.Textbox(placeholder="Type here >> ", container=False, scale=10, min_width=250)
        submit = gr.Button(value="Submit", variant="primary", scale=1, min_width=20)

    with gr.Row():
        report = gr.Button(value="REPORT", variant="secondary",
                           link="https://padlet.com/akellyirl/strathbot-flagging-2b4ko3rhk94wja6e")
        clear = gr.ClearButton([msg, chatbot])
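
    # Example prompts are rendered as editable textboxes; focusing one copies
    # its text into the main input box via on_select.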
examples=(["What are the main topics?","Explain this very simply", |
|
"Suggest a topic","Tell me more about that", |
|
"Provide more reading"]) |
|
|
|
def on_select(ex): |
|
return ex |
|
|
|
gr.Markdown("#### *Examples:*") |
|
ex = {} |
|
with gr.Group("Examples"): |
|
with gr.Row(): |
|
for ind, exa in enumerate(examples): |
|
ex[ind] = gr.Textbox(exa, container=False, interactive=True) |
|
ex[ind].focus(fn=on_select, inputs=ex[ind], outputs=msg) |
|
|
|
|
|
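
    # Wire both Enter in the textbox and the Submit button to predict, then log
    # the turn to Airtable once the response has been returned.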
    gr.on(triggers=[msg.submit, submit.click],
          fn=predict, inputs=[msg, chatbot], outputs=[msg, chatbot],
          concurrency_limit=100).then(add_log_entry, session_id)


chat.launch()