|
from typing import Any, List, Tuple |
|
import gradio as gr |
|
from langchain_openai import OpenAIEmbeddings |
|
from langchain_community.vectorstores import Chroma |
|
from langchain.chains import ConversationalRetrievalChain |
|
from langchain_openai import ChatOpenAI |
|
from langchain_community.document_loaders import PyMuPDFLoader |
|
import fitz |
|
from PIL import Image |
|
import os |
|
import openai |
|
|
|
|
|
class MyApp:
    """Process-wide state for the DBT chatbot: API key, loaded PDF and QA chain."""

    def __init__(self) -> None:
        # OpenAI key supplied through set_api_key(); None until then.
        self.OPENAI_API_KEY: "str | None" = None
        # ConversationalRetrievalChain created by build_chain(); None until then.
        self.chain = None
        # (question, answer) pairs passed to the chain as conversation memory.
        self.chat_history: list = []
        # Documents loaded from the PDF by process_file(); None until then.
        self.documents = None
        # Base name of the processed PDF; also used as the Chroma collection name.
        self.file_name = None

    def set_api_key(self, api_key: str):
        """Remember ``api_key`` and also install it as the global ``openai.api_key``."""
        self.OPENAI_API_KEY = api_key
        openai.api_key = api_key

    def process_file(self, file) -> Image.Image:
        """Load a PDF into documents and render its first page.

        Args:
            file: Any object whose ``.name`` attribute is the PDF path
                (e.g. an open file handle or a Gradio upload).

        Returns:
            The first page rendered at 150 dpi as an RGB PIL image.

        Raises:
            ValueError: If the PDF contains no pages.
        """
        loader = PyMuPDFLoader(file.name)
        self.documents = loader.load()
        self.file_name = os.path.basename(file.name)

        # Close the PyMuPDF document as soon as the page has been rasterised.
        with fitz.open(file.name) as doc:
            if len(doc) == 0:
                raise ValueError(f"{self.file_name} has no pages to render.")
            pix = doc[0].get_pixmap(dpi=150)
            return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

    def build_chain(self, file) -> str:
        """Index the loaded documents in Chroma and build the retrieval QA chain.

        Args:
            file: Unused; kept for backward compatibility. The documents loaded
                by the most recent process_file() call are the ones indexed.

        Returns:
            A success message for the UI.

        Raises:
            ValueError: If process_file() has not loaded any documents yet.
        """
        if not self.documents:
            raise ValueError("No documents loaded; call process_file() before build_chain().")

        embeddings = OpenAIEmbeddings(openai_api_key=self.OPENAI_API_KEY)
        # Deterministic ids mean rebuilding into the same collection (e.g. when
        # the API key is set again) replaces the existing chunks instead of
        # adding a second copy under fresh random ids.
        ids = [f"{self.file_name}-{i}" for i in range(len(self.documents))]
        pdfsearch = Chroma.from_documents(
            self.documents,
            embeddings,
            ids=ids,
            collection_name=self.file_name,
        )

        self.chain = ConversationalRetrievalChain.from_llm(
            ChatOpenAI(temperature=0.0, openai_api_key=self.OPENAI_API_KEY),
            retriever=pdfsearch.as_retriever(search_kwargs={"k": 1}),
            return_source_documents=True,
        )
        return "Vector database built successfully!"
|
|
|
|
|
def add_text(history: List[Tuple[str, str]], text: str) -> List[Tuple[str, str]]:
    """Append the user's message to the chat as a turn with an empty reply.

    Args:
        history: Current chatbot history; ``None`` is treated as empty.
        text: The user's message.

    Returns:
        The history with ``(text, "")`` appended.

    Raises:
        gr.Error: If ``text`` is empty or contains only whitespace.
    """
    # Whitespace-only input would otherwise be sent to the LLM as a question.
    if not text or not text.strip():
        raise gr.Error("Enter text")
    if history is None:
        history = []
    history.append((text, ""))
    return history
|
|
|
|
|
def get_response(history, query):
    """Answer ``query`` with the retrieval chain and fill in the last chat turn.

    Kept for backward compatibility. This was a verbatim copy of
    ``get_response_current``, so it now delegates to it and the two handlers
    cannot drift apart.

    Args:
        history: Chatbot history as ``(user, bot)`` pairs.
        query: The user's question.

    Returns:
        ``(history, sources)`` exactly as produced by ``get_response_current``.

    Raises:
        gr.Error: If the chain has not been built yet.
    """
    return get_response_current(history, query)
|
|
|
|
|
def get_response_current(history, query):
    """Answer ``query`` with the retrieval chain and fill in the last chat turn.

    Args:
        history: Chatbot history as ``(user, bot)`` pairs; its last pair is the
            placeholder appended by ``add_text``.
        query: The user's question.

    Returns:
        ``(history, sources)``: the history with its last reply filled in, and
        the retrieved source passages (or an error message) for display.

    Raises:
        gr.Error: If the chain has not been built yet.
    """
    if app.chain is None:
        raise gr.Error("The chain has not been built yet. Please ensure the vector database is built before querying.")

    fallback = "I have no information about it. Feed me knowledge, please!"
    try:
        result = app.chain.invoke(
            {"question": query, "chat_history": app.chat_history}
        )
        answer = result["answer"]
    except Exception as e:
        # UI boundary: report the failure in the sources box instead of
        # crashing, and show the fallback reply rather than an empty bubble.
        answer = fallback
        source_texts_str = f"{fallback} Error: {str(e)}"
    else:
        source_texts = []
        for doc in result.get("source_documents") or []:
            page = doc.metadata.get("page")
            # Page metadata is 0-based; show it 1-based when present.
            label = f"Page {page + 1}" if isinstance(page, int) else "Page ?"
            source_texts.append(f"{label}: {doc.page_content}")
        source_texts_str = "\n\n".join(source_texts)

    # Record the turn exactly once, after every fallible step has run.
    app.chat_history.append((query, answer))
    if history:
        history[-1] = (history[-1][0], answer)
    return history, source_texts_str
|
|
|
|
|
def render_file(file) -> Image.Image:
    """Render the first page of a PDF as an image.

    Args:
        file: Any object whose ``.name`` attribute is the PDF path.

    Returns:
        The first page rendered at 150 dpi as an RGB PIL image.
    """
    # The context manager closes the PyMuPDF document once the page is rasterised.
    with fitz.open(file.name) as doc:
        pix = doc[0].get_pixmap(dpi=150)
        return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
|
|
|
|
|
def purge_chat_and_render_first(file) -> Image.Image:
    """Clear the conversation memory and render the first page of a PDF.

    Args:
        file: Any object whose ``.name`` attribute is the PDF path.

    Returns:
        The first page rendered at 150 dpi as an RGB PIL image.
    """
    app.chat_history = []
    # The context manager closes the PyMuPDF document once the page is rasterised.
    with fitz.open(file.name) as doc:
        pix = doc[0].get_pixmap(dpi=150)
        return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
|
|
|
|
|
def refresh_chat():
    """Forget the conversation memory and hand the chatbot an empty log."""
    cleared_display = []
    app.chat_history = []
    return cleared_display
|
|
|
# Module-level application state read and mutated by the Gradio callbacks.
# NOTE(review): this single instance is shared by every browser session, so the
# API key, vector store and chat history are global to all users of the app;
# consider per-session gr.State if more than one user is expected.
app = MyApp()
|
|
|
|
|
def set_api_key(api_key):
    """Store the OpenAI API key and (re)build the DBT vector database.

    Args:
        api_key: The key entered in the UI.

    Returns:
        A status message with the key masked.

    Raises:
        gr.Error: If the key is empty or the vector database cannot be built.
    """
    api_key = (api_key or "").strip()
    if not api_key:
        raise gr.Error("Please enter your OpenAI API Key.")

    app.set_api_key(api_key)

    saved_file_path = "THEDIA1.pdf"
    try:
        # process_file/build_chain only need an object exposing ``.name``.
        with open(saved_file_path, 'rb') as saved_file:
            app.process_file(saved_file)
            app.build_chain(saved_file)
    except Exception as e:
        # Turn missing-file / auth / network failures into a readable UI error.
        raise gr.Error(f"Failed to build the vector database: {e}") from e

    # Never echo a short key in full: its prefix and suffix would cover all of it.
    masked = f"{api_key[:4]}...{api_key[-4:]}" if len(api_key) > 8 else "****"
    return f"API Key set to {masked} and vector database built successfully!"
|
|
|
|
|
# Example prompts listed on the "Questions" tab of the UI.
questions = [
    "What is the primary goal of Dialectical Behaviour Therapy?",
    "How can mindfulness help in managing emotions?",
    "What are some techniques to handle distressing situations?",
    "Can you explain the concept of radical acceptance?",
    "How does DBT differ from other types of therapy?",
]
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("🧘♀️ **Dialectical Behaviour Therapy**")
    gr.Markdown(
        "Disclaimer: This chatbot is based on a DBT exercise book that is publicly available. "
        "We are not medical practitioners, and the use of this chatbot is at your own responsibility."
    )

    # API key entry; setting the key also (re)builds the vector database.
    api_key_input = gr.Textbox(label="OpenAI API Key", type="password", placeholder="Enter your OpenAI API Key")
    api_key_btn = gr.Button("Set API Key")
    # Read-only status line: "API Key status" is its label, not its initial value.
    api_key_status = gr.Textbox(label="API Key status", interactive=False)

    api_key_btn.click(
        fn=set_api_key,
        inputs=[api_key_input],
        outputs=[api_key_status],
    )

    with gr.Tab("Take a Dialectical Behaviour Therapy with Me"):
        with gr.Column():
            chatbot_current = gr.Chatbot(elem_id="chatbot_current")
            # ``scale`` only sizes siblings inside a Row, so group input and buttons.
            with gr.Row():
                txt_current = gr.Textbox(
                    show_label=False,
                    placeholder="Enter text and press submit",
                    scale=2,
                )
                submit_btn_current = gr.Button("Submit", scale=1)
                refresh_btn_current = gr.Button("Refresh Chat", scale=1)
            source_texts_output_current = gr.Textbox(label="Source Texts", interactive=False)

        # Show the question immediately, then answer it, then clear the input box.
        # The clear runs last (and only on success) so get_response_current still
        # reads the query and a failed request keeps the user's text.
        submit_btn_current.click(
            fn=add_text,
            inputs=[chatbot_current, txt_current],
            outputs=[chatbot_current],
            queue=False,
        ).success(
            fn=get_response_current,
            inputs=[chatbot_current, txt_current],
            outputs=[chatbot_current, source_texts_output_current],
        ).success(
            fn=lambda: "",
            inputs=None,
            outputs=[txt_current],
        )

        # Reset the conversation and also clear the now-stale source passages.
        refresh_btn_current.click(
            fn=refresh_chat,
            inputs=[],
            outputs=[chatbot_current],
        ).then(
            fn=lambda: "",
            inputs=None,
            outputs=[source_texts_output_current],
        )

    with gr.Tab("Questions"):
        gr.Markdown("### Example Questions")
        for question in questions:
            gr.Markdown(f"- {question}")
|
|
|
# Enable request queuing on the Blocks app, then start the web server.
demo.queue().launch()
|
|