import os
import tempfile
import logging
from operator import itemgetter

import torch
import gradio as gr
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores.chroma import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import AIMessage, HumanMessage
from langchain.globals import set_debug
# Configure logging and enable LangChain debug output
logging.basicConfig(level=logging.INFO)
set_debug(True)

load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")

persist_dir = "./chroma_db"
device = "cuda:0"
model_name = "all-mpnet-base-v2"
model_kwargs = {"device": device if torch.cuda.is_available() else "cpu"}
logging.info(f"Using device {model_kwargs['device']}")

# Embedding model, used both to index document chunks and to embed queries
embeddings = HuggingFaceEmbeddings(model_name=model_name, show_progress=True, model_kwargs=model_kwargs)
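
# Optional sanity check (a sketch, not called at import time): all-mpnet-base-v2
# produces 768-dimensional sentence vectors, so one quick way to verify the
# embedding model loaded correctly is to embed a sample string. The sample
# text and function name are illustrative only.
def _check_embeddings():
    vec = embeddings.embed_query("Probetext")
    assert len(vec) == 768, f"unexpected embedding size {len(vec)}"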
def configure_retriever(local_files, chunk_size=12500, chunk_overlap=2500):
    """Build (or load) the Chroma vector store and return it as a retriever."""
    logging.info("Configuring retriever")
    if not os.path.exists(persist_dir):
        logging.info(f"Persist directory {persist_dir} does not exist. Creating it.")
        # Read documents
        docs = []
        temp_dir = tempfile.TemporaryDirectory()
        for filename in local_files:
            logging.info(f"Reading file {filename}")
            # Read the file once, preferring the docs/ directory if present
            source_dir = "docs" if os.path.exists(os.path.join("docs", filename)) else "."
            with open(os.path.join(source_dir, filename), "rb") as src:
                file_content = src.read()
            temp_filepath = os.path.join(temp_dir.name, filename)
            with open(temp_filepath, "wb") as f:
                f.write(file_content)
            loader = PyPDFLoader(temp_filepath)
            docs.extend(loader.load())
        # Split documents into overlapping chunks and index them
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        splits = text_splitter.split_documents(docs)
        vectordb = Chroma.from_documents(splits, embeddings, persist_directory=persist_dir)
    else:
        logging.info(f"Persist directory {persist_dir} exists. Loading from it.")
        vectordb = Chroma(persist_directory=persist_dir, embedding_function=embeddings)
    # Define retriever: only return chunks above the similarity threshold
    return vectordb.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": 0.8},
    )
# Collect local PDFs, preferring a docs/ directory if present
directory = "docs" if os.path.exists("docs") else "."
local_files = [f for f in os.listdir(directory) if f.endswith(".pdf")]

# Setup LLM
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo", openai_api_key=openai_api_key, temperature=0, streaming=True
)

retriever = configure_retriever(local_files)
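
# Optional smoke test (a sketch; function name and sample question are
# illustrative only): fetch chunks for a query directly, bypassing the LLM,
# to confirm the 0.8 score threshold is not filtering everything out.
def _check_retrieval(question="Wie hoch ist der Beitrag?"):
    docs = retriever.get_relevant_documents(question)
    logging.info(f"Retrieved {len(docs)} chunks for: {question}")
    return docs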
template = """Answer the question based only on the following context:
{context}

Question: {question}

Answer in German.
"""
prompt = ChatPromptTemplate.from_template(template)
# The retriever sees only the question; the prompt then receives both the
# retrieved context and the original question
chain = (
    {
        "context": itemgetter("question") | retriever,
        "question": itemgetter("question"),
    }
    | prompt
    | llm
    | StrOutputParser()
)
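
# Example invocation (a sketch; the question string is illustrative only):
# the chain takes a dict with a "question" key, retrieves matching chunks as
# context, and returns the model's answer as a plain string.
# answer = chain.invoke({"question": "Was kostet der Beitrag?"})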
def predict(message, history):
    # First ask the LLM to translate the user message to German, so the
    # retrieval query matches the German document corpus
    message = f"Translate the following text to German: {message}"
    history_langchain_format = []
    for human, ai in history:
        history_langchain_format.append(HumanMessage(content=human))
        history_langchain_format.append(AIMessage(content=ai))
    history_langchain_format.append(HumanMessage(content=message))
    gpt_response = llm.invoke(history_langchain_format)
    # Feed the translated question into the RAG chain
    return chain.invoke({"question": gpt_response.content})
demo = gr.ChatInterface(
    predict,
    chatbot=gr.Chatbot(height=500, show_share_button=True),
    textbox=gr.Textbox(placeholder="stell mir Fragen", container=False, scale=7),
    title="Beitrag Service",
    description="Ich bin Ihr hilfreicher KI-Assistent",
    theme="soft",
    examples=["Hello"],
    cache_examples=True,
    retry_btn="Wiederholen",
    undo_btn="Vorheriges löschen",
    clear_btn="Löschen",
)

if __name__ == "__main__":
    demo.launch(show_api=False)