beitrag-service / app.py
muhtasham's picture
Update app.py
0aac03a verified
raw
history blame
No virus
6.35 kB
import os
import tempfile
import gradio as gr
import torch
import logging
import base64
from operator import itemgetter
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_community.vectorstores.chroma import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain.globals import set_debug
from dotenv import load_dotenv
def image_to_base64(image_path):
    """Read the file at *image_path* and return its bytes base64-encoded as a str."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode('utf-8')
# Configure logging. set_debug(True) additionally makes LangChain print the
# inputs and outputs of every chain/LLM call (full prompts included).
logging.basicConfig(level=logging.INFO)
set_debug(True)

# Load variables from a local .env file (if present) into the environment,
# then read the settings this app uses.
load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")
langchain_api_key = os.getenv("LANGCHAIN_API_KEY")
langchain_endpoint = os.getenv("LANGCHAIN_ENDPOINT")
langchain_project_id = os.getenv("LANGCHAIN_PROJECT")
access_key = os.getenv("ACCESS_TOKEN_SECRET")

# Directory where the Chroma vector store is persisted between runs.
persist_dir = "./chroma_db"

# Device for the local embedding model: CUDA if available, else Apple MPS,
# else CPU.
device = 'cuda:0'
model_name = "all-mpnet-base-v2"
if torch.cuda.is_available():
    _embedding_device = device
elif torch.backends.mps.is_available():
    _embedding_device = "mps"
else:
    _embedding_device = "cpu"
model_kwargs = {'device': _embedding_device}
logging.info(f"Using device {model_kwargs['device']}")

# When True, use paid OpenAI embeddings instead of the local HuggingFace model.
# NOTE(review): an existing persisted store was built with whichever model was
# active at the time; flipping this flag without deleting it likely mixes
# incompatible vectors — confirm before changing.
embed_money = False
# Create embeddings and store in vectordb
if embed_money:
    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
    logging.info("Using OpenAI embeddings")
else:
    embeddings = HuggingFaceEmbeddings(model_name=model_name, show_progress=True, model_kwargs=model_kwargs)
    logging.info("Using HuggingFace embeddings")
def _read_pdf_documents(local_files):
    """Load every PDF named in *local_files* into one flat list of documents.

    Each name is looked up under ``docs/`` first and falls back to the current
    directory. The bytes are staged in a temporary directory that is removed
    once all files have been parsed.
    """
    docs = []
    with tempfile.TemporaryDirectory() as temp_dir:
        for filename in local_files:
            logging.info(f"Reading file {filename}")
            source_path = os.path.join("docs", filename)
            if not os.path.exists(source_path):
                source_path = os.path.join(".", filename)
            with open(source_path, "rb") as src:
                file_content = src.read()
            temp_filepath = os.path.join(temp_dir, filename)
            with open(temp_filepath, "wb") as dst:
                dst.write(file_content)
            # load() returns a fully materialised list, so the staged file is
            # no longer needed once this call returns.
            docs.extend(PyPDFLoader(temp_filepath).load())
    return docs


def _build_retriever(vectordb):
    """Return the MMR retriever used by the RAG chain for *vectordb*."""
    # MMR returns k=6 chunks; lambda_mult=0.25 (<0.5) favours diversity among
    # the results over pure similarity to the query.
    return vectordb.as_retriever(search_type="mmr", search_kwargs={'k': 6, 'lambda_mult': 0.25})


def configure_retriever(local_files, chunk_size=15000, chunk_overlap=2500):
    """Return a retriever over the Chroma store at ``persist_dir``.

    If ``persist_dir`` does not exist, the PDFs named in *local_files* are
    loaded and split into chunks of *chunk_size* characters, with
    *chunk_overlap* characters shared between neighbours. The chunks are then
    embedded and persisted there. Otherwise the existing store is opened and
    *local_files* is ignored.

    Raises:
        ValueError: a new store must be built but the PDFs yield no text chunks.
    """
    logging.info("Configuring retriever")
    if os.path.exists(persist_dir):
        logging.info(f"Persist directory {persist_dir} exists. Loading from it.")
        vectordb = Chroma(persist_directory=persist_dir, embedding_function=embeddings)
        return _build_retriever(vectordb)

    logging.info(f"Persist directory {persist_dir} does not exist. Creating it.")
    docs = _read_pdf_documents(local_files)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    splits = text_splitter.split_documents(docs)
    if not splits:
        # Fail before touching Chroma. Otherwise an empty persist_dir could be
        # left behind, and every later start would take the "load" branch and
        # serve an empty index.
        raise ValueError(
            f"No text could be extracted from {len(local_files)} PDF file(s); "
            f"refusing to create an empty vector store at {persist_dir}."
        )
    vectordb = Chroma.from_documents(splits, embeddings, persist_directory=persist_dir)
    return _build_retriever(vectordb)
def _list_pdf_files(directory):
    """Return the names of ``.pdf`` entries (any letter case) in *directory*, sorted.

    Sorting makes the ingestion order deterministic; ``os.listdir`` order is
    arbitrary.
    """
    return sorted(
        name for name in os.listdir(directory)
        if name.lower().endswith(".pdf")
    )


# PDFs are read from ./docs when that folder exists, otherwise from the CWD.
directory = "docs" if os.path.exists("docs") else "."
local_files = _list_pdf_files(directory)
# Chat model shared by the translation chain and the RAG chain. Streaming is
# enabled so chain_rag.stream() can emit the answer incrementally.
llm = ChatOpenAI(
    model_name="gpt-4-0125-preview",
    temperature=0.1,
    streaming=True,
    openai_api_key=openai_api_key,
)

# Build the vector store from the discovered PDFs, or load the persisted one.
retriever = configure_retriever(local_files)
# Prompt for the RAG chain: {context} receives the retrieved documents,
# {question} the translated user query, and {history} the earlier turns.
template = """Answer the question based only on the following context:
{context}
Question: {question}
Chat History: {history}
Answer in German Language. If the question is not related to the context, answer with "I don't know".
If the user is asking for follow-up questions on the same topic, generate different questions than you already answered.
"""
prompt = ChatPromptTemplate.from_template(template=template)
# Bare LLM call parsed to a string; predict() uses it to translate the query.
chain_translate = llm | StrOutputParser()

# RAG pipeline. The input dict must carry "question" and "history". The
# question is also sent through the retriever to fill {context}. All three
# values are rendered into the prompt, answered by the LLM and parsed to text.
_rag_inputs = {
    "context": itemgetter("question") | retriever,
    "question": itemgetter("question"),
    "history": itemgetter("history"),
}
chain_rag = _rag_inputs | prompt | llm | StrOutputParser()
def _format_history(history):
    """Render chat history as plain ``Human: ...`` / ``AI: ...`` lines.

    *history* is iterated as (user, assistant) pairs. Empty or ``None`` turns
    are skipped, so a turn without a reply yet adds no blank entry.
    """
    lines = []
    for human, ai in history:
        if human:
            lines.append(f"Human: {human}")
        if ai:
            lines.append(f"AI: {ai}")
    return "\n".join(lines)


def predict(message, history):
    """Translate *message* to English via the LLM, then stream a RAG answer.

    Args:
        message: the user's latest chat input.
        history: earlier turns as (user, assistant) pairs, as passed by Gradio.

    Yields:
        The answer accumulated so far, growing with each streamed chunk.
    """
    query = chain_translate.invoke(
        f"Translate this query to English if it is in German otherwise return original content: {message}"
    )
    # The prompt is a plain string template. Rendering history as text avoids
    # interpolating the Python repr of message objects into it. The current
    # question is already supplied via {question}, so it is not repeated here.
    rag_input = {"question": query, "history": _format_history(history)}
    partial_message = ""
    for chunk in chain_rag.stream(rag_input):
        partial_message += chunk
        yield partial_message
# Prefer the logo under ./ui/, otherwise fall back to ./logo.png.
_preferred_logo = "./ui/logo.png"
if os.path.exists(_preferred_logo):
    image_path = _preferred_logo
else:
    image_path = "./logo.png"
logo_base64 = image_to_base64(image_path)

# Page CSS. A block-level body::before pseudo-element shows the logo, inlined
# as a base64 data URI, at the top of the page. The #q-output rule caps that
# element's height and makes it scroll.
# NOTE(review): no component in this file sets elem_id="q-output"; that rule
# may be unused — confirm.
css = f"""
body::before {{
content: '';
display: block;
height: 150px !important; /* Adjust based on your logo's size */
background: url('data:image/png;base64,{logo_base64}') no-repeat center center !important;
background-size: contain !important; /* This makes sure the logo fits well in the header */
}}
#q-output {{
max-height: 60vh !important;
overflow: auto !important;
}}
"""
# One-click example prompts shown beneath the input box.
_example_questions = [
    "Generate auditing questions about Change Management",
    "Generate auditing questions about Software Maintenance",
    "Generate auditing questions about Data Protection",
]

# Chat UI (German labels) wired to predict(); launched with the API page hidden.
demo = gr.ChatInterface(
    fn=predict,
    chatbot=gr.Chatbot(likeable=True, show_share_button=False, show_copy_button=True),
    textbox=gr.Textbox(placeholder="stell mir Fragen", scale=7),
    description="Ich bin Ihr hilfreicher KI-Assistent",
    theme="soft",
    submit_btn="Senden",
    retry_btn="🔄 Wiederholen",
    undo_btn="⏪ Rückgängig",
    clear_btn="🗑️ Löschen",
    examples=_example_questions,
    # cache_examples=True,  (intentionally left disabled)
    fill_height=True,
    css=css,
)
demo.launch(show_api=False)