beitrag-service / app.py
muhtasham's picture
Update app.py
a1b0c16 verified
raw
history blame
No virus
7.18 kB
import os
import tempfile
import gradio as gr
import torch
import logging
import base64
from operator import itemgetter
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_community.vectorstores.chroma import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain.globals import set_debug
from dotenv import load_dotenv
def image_to_base64(image_path):
    """Return the bytes of the file at *image_path* as a Base64 text string.

    Args:
        image_path: Path of the file to encode (opened in binary mode).

    Returns:
        The Base64 encoding of the file contents, decoded as UTF-8 text.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode('utf-8')
# --- Logging --------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
# NOTE(review): LangChain global debug mode prints chain inputs/outputs
# (including user messages); consider turning this off in production.
set_debug(True)

# --- Secrets / settings from the environment (optionally via a .env file) --
load_dotenv()
(
    openai_api_key,
    langchain_api_key,
    langchain_endpoint,
    langchain_project_id,
    access_key,
) = (
    os.getenv(env_name)
    for env_name in (
        "OPENAI_API_KEY",
        "LANGCHAIN_API_KEY",
        "LANGCHAIN_ENDPOINT",
        "LANGCHAIN_PROJECT",
        "ACCESS_TOKEN_SECRET",
    )
)

# --- Vector store and embedding model settings ----------------------------
persist_dir = "./chroma_db"
device = 'cuda:0'
model_name = "all-mpnet-base-v2"
# Prefer CUDA, then Apple MPS, then plain CPU for the local embedding model.
if torch.cuda.is_available():
    model_kwargs = {'device': device}
elif torch.backends.mps.is_available():
    model_kwargs = {'device': "mps"}
else:
    model_kwargs = {'device': "cpu"}
logging.info(f"Using device {model_kwargs['device']}")
# When True, use (paid) OpenAI embeddings instead of the local model.
embed_money = False
# Select the embedding model for the vector store. This block only creates
# it; the store is built or reopened in configure_retriever.
# embed_money=True uses OpenAI API embeddings. Otherwise a local HuggingFace
# model (model_name) runs on model_kwargs["device"].
# NOTE(review): an existing persisted store is reopened with whichever backend
# is active now. Flipping this flag without deleting that directory likely
# mixes incompatible vectors.
if embed_money:
    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
    logging.info("Using OpenAI embeddings")
else:
    embeddings = HuggingFaceEmbeddings(model_name=model_name, show_progress=True, model_kwargs=model_kwargs)
    logging.info("Using HuggingFace embeddings")
def _read_local_file(filename):
    """Return the raw bytes of *filename*, looking in ``docs/`` first, then the cwd."""
    docs_path = os.path.join("docs", filename)
    source_path = docs_path if os.path.exists(docs_path) else os.path.join(".", filename)
    # ``with`` closes the handle right away instead of leaking it until GC.
    with open(source_path, "rb") as source:
        return source.read()


def _load_pdf_documents(local_files):
    """Copy each PDF into a scratch directory and load it with PyPDFLoader."""
    docs = []
    # Using the directory as a context manager deletes it as soon as loading
    # is done, rather than relying on the garbage collector's finalizer.
    with tempfile.TemporaryDirectory() as temp_dir:
        for filename in local_files:
            logging.info(f"Reading file {filename}")
            temp_filepath = os.path.join(temp_dir, filename)
            with open(temp_filepath, "wb") as target:
                target.write(_read_local_file(filename))
            loader = PyPDFLoader(temp_filepath)
            docs.extend(loader.load())
    return docs


def configure_retriever(local_files, chunk_size=15000, chunk_overlap=2500):
    """Return an MMR retriever over the Chroma vector store at ``persist_dir``.

    When ``persist_dir`` does not exist yet:
      1. The PDFs in *local_files* are loaded.
      2. They are split into overlapping chunks.
      3. The chunks are embedded with the module-level ``embeddings``.
      4. The result is persisted to ``persist_dir``.

    When ``persist_dir`` exists, it is reopened as-is. In that case
    *local_files*, *chunk_size* and *chunk_overlap* are ignored, so added or
    changed PDFs are not indexed until the directory is removed.

    Args:
        local_files: PDF file names, resolved against ``docs/`` or the cwd.
        chunk_size: Maximum chunk length passed to the text splitter.
        chunk_overlap: Overlap between consecutive chunks.

    Returns:
        A retriever that uses maximal-marginal-relevance search. It returns
        k=6 chunks with lambda_mult=0.25, which favours diversity over pure
        similarity.
    """
    logging.info("Configuring retriever")
    if not os.path.exists(persist_dir):
        logging.info(f"Persist directory {persist_dir} does not exist. Creating it.")
        docs = _load_pdf_documents(local_files)
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        splits = text_splitter.split_documents(docs)
        vectordb = Chroma.from_documents(splits, embeddings, persist_directory=persist_dir)
    else:
        logging.info(f"Persist directory {persist_dir} exists. Loading from it.")
        # Use persist_dir (not a hard-coded path) so both branches agree.
        vectordb = Chroma(persist_directory=persist_dir, embedding_function=embeddings)
    return vectordb.as_retriever(search_type="mmr", search_kwargs={'k': 6, 'lambda_mult': 0.25})
# Source PDFs live in ./docs when that folder exists, else in the cwd.
directory = "." if not os.path.exists("docs") else "docs"
local_files = list(filter(lambda entry: entry.endswith(".pdf"), os.listdir(directory)))
def setup_llm(system_message):
    """Build the RAG chain and the translation chain for one chat request.

    Args:
        system_message: Free-text system prompt (editable in the UI). It is
            prepended to the RAG prompt template.

    Returns:
        A ``(chain_rag, chain_translate)`` tuple:
          - ``chain_rag`` expects a dict with ``question``, ``original_msg``
            and ``history``. It retrieves context for ``question``, fills the
            prompt, and yields the model's text.
          - ``chain_translate`` sends a raw prompt string to the model and
            returns its text.
    """
    # Setup LLM
    llm = ChatOpenAI(
        model_name="gpt-4o", openai_api_key=openai_api_key, temperature=0.1, streaming=True
    )
    # NOTE(review): predict() calls setup_llm per message, so the model client
    # and vector store are re-created on every request; consider caching.
    retriever = configure_retriever(local_files)
    # The system prompt is user-editable and ChatPromptTemplate uses f-string
    # syntax. A literal "{" or "}" in it would be treated as a template
    # variable, failing at parse time or with "missing variables" on invoke.
    # Doubling the braces makes them literal text.
    safe_system_message = system_message.replace("{", "{{").replace("}", "}}")
    template = safe_system_message + """
Answer the question based only on the following context in it's original language.
{context}
Question: {question}
Original Message: {original_msg}
Chat History: {history}
If the question is not related to the context, answer with "I don't know" in the original language.
If the user is asking for follow-up questions on the same topic, generate different questions than you already answered.
If the user is asking to explain the context, or expand on the context, then provide explanation in the original language.
"""
    prompt = ChatPromptTemplate.from_template(template)
    chain_translate = (
        llm
        | StrOutputParser()
    )
    chain_rag = (
        {
            # Retrieval runs on the (translated) question only.
            "context": itemgetter("question") | retriever,
            "question": itemgetter("question"),
            "original_msg": itemgetter("original_msg"),
            "history": itemgetter("history")
        }
        | prompt
        | llm
        | StrOutputParser()
    )
    return chain_rag, chain_translate
def _format_history(history):
    """Render Gradio ``(user, assistant)`` pairs as a plain-text transcript."""
    lines = []
    for human, ai in history:
        # Skip empty/None turns (e.g. an assistant reply that never arrived).
        if human:
            lines.append(f"Human: {human}")
        if ai:
            lines.append(f"Assistant: {ai}")
    return "\n".join(lines)


def predict(message, history, system_message):
    """Stream an answer to *message* for the Gradio ChatInterface.

    Steps:
      1. Ask the model to translate *message* to English if it is German.
      2. Run the RAG chain with the translated query, the original message,
         and the prior turns.

    Args:
        message: The user's current chat message.
        history: Prior turns as ``(user, assistant)`` pairs from Gradio.
        system_message: System prompt from the "System Prompt" textbox.

    Yields:
        The answer accumulated so far, growing with each streamed chunk.
    """
    logging.info(system_message)
    chain_rag, chain_translate = setup_llm(system_message)
    message_translated = chain_translate.invoke(f"Translate this query to English if it is in German otherwise return original contetn: {message}")
    # The prompt fills {history} via string formatting, so pass readable text.
    # A list of message objects would be inserted as their Python reprs. The
    # current message is not repeated here: it already goes in as
    # question / original_msg.
    history_text = _format_history(history)
    partial_message = ""
    for chunk in chain_rag.stream({"question": message_translated, "original_msg": message, "history": history_text}):
        partial_message += chunk
        yield partial_message
# Use the logo under ./ui/ when present, otherwise the one next to the app.
if os.path.exists("./ui/logo.png"):
    image_path = "./ui/logo.png"
else:
    image_path = "./logo.png"
logo_base64 = image_to_base64(image_path)
# Page CSS: show the logo as an inline data URI in a header strip above the
# body, and cap the height of the #q-output element so it scrolls.
css = f"""
body::before {{
content: '';
display: block;
height: 150px !important; /* Adjust based on your logo's size */
background: url('data:image/png;base64,{logo_base64}') no-repeat center center !important;
background-size: contain !important; /* This makes sure the logo fits well in the header */
}}
#q-output {{
max-height: 60vh !important;
overflow: auto !important;
}}
"""
# Build the chat UI components in the same order as they are passed below.
chatbot_widget = gr.Chatbot(likeable=True, show_share_button=False, show_copy_button=True)
question_box = gr.Textbox(placeholder="stell mir Fragen", scale=7)
system_prompt_box = gr.Textbox("You are an auditor with many years of professional experience and are to develop a questionnaire on the topic of home office in the form of a self-assessment for me. As a basis for the questionnaire, you use standards and best practices (for example, from ISO 27001 and COBIT). The questionnaire should not exceed 20 questions.", label="System Prompt")

# German-language chat UI. The system prompt box is an extra input
# forwarded to predict() as its third argument.
chat_ui = gr.ChatInterface(
    predict,
    chatbot=chatbot_widget,
    textbox=question_box,
    description="Ich bin Ihr hilfreicher KI-Assistent",
    theme="soft",
    submit_btn="Senden",
    retry_btn="🔄 Wiederholen",
    undo_btn="⏪ Rückgängig",
    clear_btn="🗑️ Löschen",
    additional_inputs=[system_prompt_box],
    cache_examples=False,
    fill_height=True,
    css=css,
)
chat_ui.launch(show_api=False)