import os
import tempfile
import gradio as gr
import torch
import logging
import base64
from operator import itemgetter
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_community.vectorstores.chroma import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain.globals import set_debug
from dotenv import load_dotenv


def image_to_base64(image_path):
    """Return the Base64-encoded contents of an image file as a UTF-8 string."""
    with open(image_path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
    return encoded_string


# Configure logging and LangChain debug output
logging.basicConfig(level=logging.INFO)
set_debug(True)

# Load API keys and tracing configuration from .env
load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")
langchain_api_key = os.getenv("LANGCHAIN_API_KEY")
langchain_endpoint = os.getenv("LANGCHAIN_ENDPOINT")
langchain_project_id = os.getenv("LANGCHAIN_PROJECT")
access_key = os.getenv("ACCESS_TOKEN_SECRET")

persist_dir = "./chroma_db"
device = 'cuda:0'
model_name = "all-mpnet-base-v2"
# Prefer CUDA, then Apple MPS, then CPU for the embedding model
model_kwargs = {
    'device': device if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
}
logging.info(f"Using device {model_kwargs['device']}")

embed_money = False

# Create embeddings and store in vectordb
if embed_money:
    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
    logging.info("Using OpenAI embeddings")
else:
    embeddings = HuggingFaceEmbeddings(model_name=model_name, show_progress=True, model_kwargs=model_kwargs)
    logging.info("Using HuggingFace embeddings")


def configure_retriever(local_files, chunk_size=15000, chunk_overlap=2500):
    logging.info("Configuring retriever")
    if not os.path.exists(persist_dir):
        logging.info(f"Persist directory {persist_dir} does not exist. Creating it.")
        # Read documents
        docs = []
        temp_dir = tempfile.TemporaryDirectory()
        for filename in local_files:
            logging.info(f"Reading file {filename}")
            # Read the PDF from docs/ if it exists there, otherwise from the current directory
            if not os.path.exists(os.path.join("docs", filename)):
                source_path = os.path.join(".", filename)
            else:
                source_path = os.path.join("docs", filename)
            with open(source_path, "rb") as source_file:
                file_content = source_file.read()
            # Copy the file into the temporary directory and load it with PyPDFLoader
            temp_filepath = os.path.join(temp_dir.name, filename)
            with open(temp_filepath, "wb") as f:
                f.write(file_content)
            loader = PyPDFLoader(temp_filepath)
            docs.extend(loader.load())

        # Split documents into overlapping chunks
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        splits = text_splitter.split_documents(docs)

        # Embed the chunks and persist them to the Chroma store
        vectordb = Chroma.from_documents(splits, embeddings, persist_directory=persist_dir)

        # Define retriever (maximal marginal relevance for more diverse results)
        retriever = vectordb.as_retriever(search_type="mmr", search_kwargs={'k': 6, 'lambda_mult': 0.25})
        return retriever
    else:
        logging.info(f"Persist directory {persist_dir} exists. Loading from it.")
        vectordb = Chroma(persist_directory=persist_dir, embedding_function=embeddings)
        # Define retriever
        retriever = vectordb.as_retriever(search_type="mmr", search_kwargs={'k': 6, 'lambda_mult': 0.25})
        return retriever


directory = "docs" if os.path.exists("docs") else "."
local_files = [f for f in os.listdir(directory) if f.endswith(".pdf")]

# Setup LLM
llm = ChatOpenAI(
    model_name="gpt-4-0125-preview",
    openai_api_key=openai_api_key,
    temperature=0.1,
    streaming=True
)

retriever = configure_retriever(local_files)

template = """Answer the question based only on the following context:
{context}

Question: {question}

Chat History: {history}

Answer in German. If the question is not related to the context, answer with "I don't know".
If the user is asking for follow-up questions on the same topic, generate different questions than you already answered.
"""
prompt = ChatPromptTemplate.from_template(template)

# Plain LLM chain used to normalize the incoming query to English
chain_translate = (
    llm
    | StrOutputParser()
)

# RAG chain: retrieve context for the question, then answer with the chat history
chain_rag = (
    {
        "context": itemgetter("question") | retriever,
        "question": itemgetter("question"),
        "history": itemgetter("history")
    }
    | prompt
    | llm
    | StrOutputParser()
)


def predict(message, history):
    message = chain_translate.invoke(
        f"Translate this query to English if it is in German, otherwise return the original content: {message}"
    )
    history_langchain_format = []
    partial_message = ""
    for human, ai in history:
        history_langchain_format.append(HumanMessage(content=human))
        history_langchain_format.append(AIMessage(content=ai))
    history_langchain_format.append(HumanMessage(content=message))

    # Stream the answer token by token back to the Gradio chat
    for response in chain_rag.stream({"question": message, "history": history_langchain_format}):
        partial_message += response
        yield partial_message


image_path = "./ui/logo.png" if os.path.exists("./ui/logo.png") else "./logo.png"
logo_base64 = image_to_base64(image_path)

# CSS with the Base64-encoded image
css = f"""
body::before {{
    content: '';
    display: block;
    height: 150px !important;  /* Adjust based on your logo's size */
    background: url('data:image/png;base64,{logo_base64}') no-repeat center center !important;
    background-size: contain !important;  /* This makes sure the logo fits well in the header */
}}
#q-output {{
    max-height: 60vh !important;
    overflow: auto !important;
}}
"""

gr.ChatInterface(
    predict,
    chatbot=gr.Chatbot(likeable=True, show_share_button=False, show_copy_button=True),
    textbox=gr.Textbox(placeholder="stell mir Fragen", scale=7),
    description="Ich bin Ihr hilfreicher KI-Assistent",
    theme="soft",
    submit_btn="Senden",
    retry_btn="🔄 Wiederholen",
    undo_btn="⏪ Rückgängig",
    clear_btn="🗑️ Löschen",
    examples=[
        "Generate auditing questions about Change Management",
        "Generate auditing questions about Software Maintenance",
        "Generate auditing questions about Data Protection"
    ],
    # cache_examples=True,
    fill_height=True,
    css=css,
).launch(show_api=False)