Spaces:
Sleeping
Sleeping
File size: 7,182 Bytes
6f7484c 281e223 6f7484c 2d02398 6f7484c 281e223 6f7484c 281e223 6f7484c 2d02398 6f7484c 2d02398 6f7484c 2d02398 6f7484c 2d02398 6f7484c 2d02398 6f7484c 2d02398 6f7484c a1b0c16 2d02398 a1b0c16 6f7484c a1b0c16 6f7484c a1b0c16 6f7484c a1b0c16 0aac03a a1b0c16 6f7484c a1b0c16 281e223 a1b0c16 6f7484c a1b0c16 281e223 a1b0c16 6f7484c a1b0c16 2d02398 a1b0c16 281e223 6f7484c 281e223 6f7484c a1b0c16 281e223 2d02398 281e223 6f7484c 281e223 0aac03a 281e223 a1b0c16 281e223 a1b0c16 281e223 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import os
import tempfile
import gradio as gr
import torch
import logging
import base64
from operator import itemgetter
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_community.vectorstores.chroma import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain.globals import set_debug
from dotenv import load_dotenv
def image_to_base64(image_path):
    """Return the contents of the file at ``image_path`` as a Base64 text string.

    The file is read in binary mode and the Base64 bytes are decoded as UTF-8.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode('utf-8')
# Logging: INFO-level app logs plus LangChain's global debug mode
# (verbose tracing of chain / LLM inputs and outputs).
logging.basicConfig(level=logging.INFO)
set_debug(True)

# Settings and secrets are read from the environment, optionally seeded
# from a local .env file.
load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")
langchain_api_key = os.getenv("LANGCHAIN_API_KEY")
langchain_endpoint = os.getenv("LANGCHAIN_ENDPOINT")
langchain_project_id = os.getenv("LANGCHAIN_PROJECT")
access_key = os.getenv("ACCESS_TOKEN_SECRET")

# Where the Chroma vector store is persisted on disk.
persist_dir = "./chroma_db"

# Local embedding model and the accelerator it runs on:
# CUDA first, then Apple MPS, otherwise CPU.
device = 'cuda:0'
model_name = "all-mpnet-base-v2"
model_kwargs = {}
if torch.cuda.is_available():
    model_kwargs['device'] = device
elif torch.backends.mps.is_available():
    model_kwargs['device'] = "mps"
else:
    model_kwargs['device'] = "cpu"
logging.info(f"Using device {model_kwargs['device']}")

# True -> OpenAI embeddings API; False -> local HuggingFace sentence model.
embed_money = False

# Embedding function used when building / loading the vector store.
if not embed_money:
    embeddings = HuggingFaceEmbeddings(model_name=model_name, show_progress=True, model_kwargs=model_kwargs)
    logging.info("Using HuggingFace embeddings")
else:
    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
    logging.info("Using OpenAI embeddings")
def configure_retriever(local_files, chunk_size=15000, chunk_overlap=2500):
    """Return an MMR retriever over the Chroma store at ``persist_dir``.

    If ``persist_dir`` does not exist yet, every PDF in ``local_files`` is
    loaded, split into overlapping chunks and embedded into a new Chroma
    store persisted there. If it already exists, the store is simply reopened
    and ``local_files``, ``chunk_size`` and ``chunk_overlap`` are ignored, so
    the index is NOT rebuilt when the PDFs change.

    Args:
        local_files: PDF file names, looked up in ``docs/`` first and in the
            current directory otherwise.
        chunk_size: maximum chunk length for the text splitter (its default
            length function counts characters).
        chunk_overlap: length shared between neighbouring chunks.

    Returns:
        A retriever using maximal-marginal-relevance search
        (``k=6``, ``lambda_mult=0.25``).

    Raises:
        ValueError: when building a new index and the PDFs yield no text.
    """
    logging.info("Configuring retriever")
    if not os.path.exists(persist_dir):
        logging.info(f"Persist directory {persist_dir} does not exist. Creating it.")
        docs = []
        for filename in local_files:
            logging.info(f"Reading file {filename}")
            docs_path = os.path.join("docs", filename)
            file_path = docs_path if os.path.exists(docs_path) else os.path.join(".", filename)
            # PyPDFLoader reads the file itself, so load it in place instead of
            # copying it to a temp dir. This avoids leaked file handles and an
            # uncleaned temporary directory, and it keeps a real path in each
            # document's "source" metadata.
            docs.extend(PyPDFLoader(file_path).load())

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        splits = text_splitter.split_documents(docs)
        if not splits:
            # Fail loudly instead of asking Chroma to index nothing. A failed
            # or empty build may leave persist_dir behind, and the branch
            # below would then silently reuse it on every later run.
            raise ValueError(
                f"No PDF text found to index (files: {local_files}); "
                f"add PDFs to ./docs or the working directory"
            )
        vectordb = Chroma.from_documents(splits, embeddings, persist_directory=persist_dir)
    else:
        logging.info(f"Persist directory {persist_dir} exists. Loading from it.")
        # Use persist_dir (not a hard-coded path) so both branches agree.
        vectordb = Chroma(persist_directory=persist_dir, embedding_function=embeddings)

    return vectordb.as_retriever(search_type="mmr", search_kwargs={'k': 6, 'lambda_mult': 0.25})
# Folder holding the source PDFs: ./docs when it exists, else the working directory.
directory = "docs" if os.path.exists("docs") else "."
# PDFs to index. Match the extension case-insensitively so files such as
# "Report.PDF" are not silently skipped. Sort so the indexing order is
# deterministic rather than whatever order os.listdir returns.
local_files = sorted(f for f in os.listdir(directory) if f.lower().endswith(".pdf"))
def _format_docs(docs):
    """Join the page text of retrieved documents into one context string."""
    return "\n\n".join(doc.page_content for doc in docs)


def setup_llm(system_message):
    """Build the RAG answering chain and a plain LLM chain used for translation.

    Args:
        system_message: free-form system prompt text, prepended to the RAG
            prompt template. It comes from a user-editable textbox, so literal
            braces in it are escaped before the template is parsed.

    Returns:
        ``(chain_rag, chain_translate)``. ``chain_rag`` takes a dict with
        ``question``, ``original_msg`` and ``history`` keys and produces a
        string. ``chain_translate`` takes a prompt string and produces a string.
    """
    # Setup LLM
    llm = ChatOpenAI(
        model_name="gpt-4o", openai_api_key=openai_api_key, temperature=0.1, streaming=True
    )
    retriever = configure_retriever(local_files)

    # from_template parses {name} placeholders (f-string format). Double any
    # braces in the user-supplied system prompt so they stay literal. Otherwise
    # a JSON snippet, for example, breaks parsing or becomes a stray input variable.
    escaped_system_message = system_message.replace("{", "{{").replace("}", "}}")
    template = escaped_system_message + """
Answer the question based only on the following context in it's original language.
{context}
Question: {question}
Original Message: {original_msg}
Chat History: {history}
If the question is not related to the context, answer with "I don't know" in the original language.
If the user is asking for follow-up questions on the same topic, generate different questions than you already answered.
If the user is asking to explain the context, or expand on the context, then provide explanation in the original language.
"""
    prompt = ChatPromptTemplate.from_template(template)

    chain_translate = llm | StrOutputParser()

    chain_rag = (
        {
            # Retrieve with the (translated) question, then pass the chunk
            # texts rather than the repr of a list of Document objects.
            "context": itemgetter("question") | retriever | _format_docs,
            "question": itemgetter("question"),
            "original_msg": itemgetter("original_msg"),
            "history": itemgetter("history"),
        }
        | prompt
        | llm
        | StrOutputParser()
    )
    return chain_rag, chain_translate
def predict(message, history, system_message):
    """Stream a RAG answer to ``message`` for the Gradio chat interface.

    The message is first sent through a translation prompt (German to English,
    otherwise returned unchanged). That version is what the retriever searches
    with; the untouched message is also given to the prompt as ``original_msg``.

    Args:
        message: the user's newest chat message.
        history: earlier turns as ``(user_text, assistant_text)`` pairs, the
            tuple format used by this ChatInterface's Chatbot.
        system_message: text of the "System Prompt" textbox.

    Yields:
        The answer accumulated so far, growing with each streamed chunk.
    """
    from langchain_core.messages import get_buffer_string

    logging.info(system_message)
    chain_rag, chain_translate = setup_llm(system_message)
    message_translated = chain_translate.invoke(f"Translate this query to English if it is in German otherwise return original contetn: {message}")

    # Rebuild the earlier turns as LangChain messages and render them as a
    # "Human: ... / AI: ..." transcript. The prompt is a plain string
    # template, so interpolating the list directly would embed each message
    # object's repr. The current message is not appended, because it already
    # reaches the prompt as Question / Original Message.
    history_messages = []
    for human, ai in history:
        history_messages.append(HumanMessage(content=human))
        # Gradio may leave the assistant slot as None (e.g. an unanswered
        # turn); AIMessage needs string/list content, so skip it.
        if ai is not None:
            history_messages.append(AIMessage(content=ai))
    history_text = get_buffer_string(history_messages)

    partial_message = ""
    for chunk in chain_rag.stream({"question": message_translated, "original_msg": message, "history": history_text}):
        partial_message += chunk
        yield partial_message
# Header logo: prefer ./ui/logo.png, fall back to ./logo.png.
image_path = "./ui/logo.png" if os.path.exists("./ui/logo.png") else "./logo.png"

# The logo is inlined into the CSS as a Base64 data URI. If no logo file
# exists, start without the header image instead of crashing at startup
# with FileNotFoundError.
if os.path.exists(image_path):
    logo_base64 = image_to_base64(image_path)
    logo_css = f"""
body::before {{
    content: '';
    display: block;
    height: 150px !important; /* Adjust based on your logo's size */
    background: url('data:image/png;base64,{logo_base64}') no-repeat center center !important;
    background-size: contain !important; /* This makes sure the logo fits well in the header */
}}
"""
else:
    logging.warning("Logo not found at %s; starting without header image", image_path)
    logo_base64 = None
    logo_css = ""

# Page CSS: the optional logo header plus a scroll limit for #q-output.
# NOTE(review): no component in this file sets elem_id="q-output"; confirm
# this rule still targets something.
css = logo_css + """
#q-output {
    max-height: 60vh !important;
    overflow: auto !important;
}
"""

# German UI labels: "ask me questions", "I am your helpful AI assistant",
# Send / Retry / Undo / Clear. The System Prompt textbox is passed to
# predict() as its third argument.
demo = gr.ChatInterface(
    predict,
    chatbot=gr.Chatbot(likeable=True, show_share_button=False, show_copy_button=True),
    textbox=gr.Textbox(placeholder="stell mir Fragen", scale=7),
    description="Ich bin Ihr hilfreicher KI-Assistent",
    theme="soft",
    submit_btn="Senden",
    retry_btn="🔄 Wiederholen",
    undo_btn="⏪ Rückgängig",
    clear_btn="🗑️ Löschen",
    additional_inputs=[
        gr.Textbox("You are an auditor with many years of professional experience and are to develop a questionnaire on the topic of home office in the form of a self-assessment for me. As a basis for the questionnaire, you use standards and best practices (for example, from ISO 27001 and COBIT). The questionnaire should not exceed 20 questions.", label="System Prompt")
    ],
    cache_examples=False,
    fill_height=True,
    css=css,
)
demo.launch(show_api=False)