Chris4K's picture
Update app.py
35fc283 verified
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from langchain_community.llms import HuggingFaceHub
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from urllib.parse import urlparse, urljoin
import requests
from bs4 import BeautifulSoup
app = FastAPI()

# Let browser front-ends served from other origins call this API.
# NOTE(review): a wildcard origin combined with allow_credentials=True means any
# site may send credentialed requests (Starlette echoes the caller's Origin in
# that case). Confirm credentials are really needed, or list explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Function to crawl all URLs from a domain
def get_all_links_from_domain(domain_url):
print("domain url " + domain_url)
visited_urls = set()
domain_links = set()
parsed_initial_url = urlparse(domain_url)
base_domain = parsed_initial_url.netloc
get_links_from_page(domain_url, visited_urls, domain_links, domain_url)
return domain_links
def get_links_from_page(url, visited_urls, all_links, base_domain, timeout=10):
    """Crawl from *url*, collecting every link that starts with *base_domain*.

    The crawl uses an explicit stack instead of recursion, so large sites do
    not hit Python's recursion limit. Relative hrefs are resolved against the
    page they appear on. ``#fragment`` suffixes are dropped, so anchors within
    one page do not trigger repeated downloads of that page.

    Args:
        url: Absolute URL to start crawling from.
        visited_urls: Set of URLs already fetched; updated in place.
        all_links: Set that receives every in-scope link discovered; updated
            in place.
        base_domain: URL prefix (e.g. ``"https://example.com/faq/"``). A link
            must start with it to be collected and followed.
        timeout: Seconds to wait for each HTTP request before giving up.

    Returns:
        None. Results accumulate in *visited_urls* and *all_links*.
    """
    from urllib.parse import urldefrag  # local import keeps this block self-contained

    pending = [url]
    while pending:
        current_url = pending.pop()
        if not current_url.startswith(base_domain) or current_url in visited_urls:
            continue
        visited_urls.add(current_url)
        print("Getting next " + current_url)
        try:
            response = requests.get(current_url, timeout=timeout)
        except requests.RequestException as exc:
            # One unreachable page should not abort the whole crawl.
            print(f"Failed to retrieve content from {current_url}. Error: {exc}")
            continue
        if response.status_code != 200:
            print(f"Failed to retrieve content from {current_url}. Status code: {response.status_code}")
            continue

        soup = BeautifulSoup(response.content, 'html.parser')
        for link in soup.find_all('a', href=True):
            # Resolve against the current page rather than the site root.
            # Otherwise a relative href like "details.html" on /faq/a.html
            # would wrongly become /details.html.
            absolute_url, _fragment = urldefrag(urljoin(current_url, link.get('href')))
            if absolute_url in visited_urls or not absolute_url.startswith(base_domain):
                continue
            all_links.add(absolute_url)
            pending.append(absolute_url)
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain.memory import ConversationBufferMemory
from langchain_community.document_transformers import BeautifulSoupTransformer
# Loaded embedding models, keyed by (model name, normalize flag). The
# multi-gigabyte model is loaded once per process instead of on every call;
# get_response() reaches this code on each request.
_EMBEDDINGS_CACHE = {}


def _get_embeddings(model_name, normalize_embeddings):
    """Return a cached CPU ``HuggingFaceEmbeddings`` instance for *model_name*."""
    key = (model_name, normalize_embeddings)
    if key not in _EMBEDDINGS_CACHE:
        _EMBEDDINGS_CACHE[key] = HuggingFaceEmbeddings(
            model_name=model_name,
            encode_kwargs={"normalize_embeddings": normalize_embeddings},
            model_kwargs={"device": "cpu"},
        )
    return _EMBEDDINGS_CACHE[key]


def index_urls_in_rag(urls=None):
    """Fetch each URL, split its text into chunks, and add them to the Chroma store.

    The vector store is persisted under ``/home/user/.cache/chroma_db``.
    Calling this with no URLs therefore just opens the existing store.

    Args:
        urls: Iterable of page URLs to index. ``None`` (the default) indexes nothing.

    Returns:
        A retriever over the persisted vector store.
    """
    urls = [] if urls is None else urls

    rag_model = "intfloat/e5-mistral-7b-instruct"
    # NOTE(review): set normalize_embeddings=True to compute cosine similarity
    # (Chroma's default distance is L2). It is kept False to match the
    # existing persisted index.
    embeddings = _get_embeddings(rag_model, normalize_embeddings=False)
    vector_store = Chroma(persist_directory="/home/user/.cache/chroma_db", embedding_function=embeddings)

    # Both helpers are stateless for our purposes, so build them once.
    bs_transformer = BeautifulSoupTransformer()
    text_splitter = RecursiveCharacterTextSplitter()

    for url in urls:
        document = WebBaseLoader(url).load()
        # The keyword is tags_to_extract. The old misspelling `class_to_extract`
        # was silently absorbed by **kwargs and ignored.
        docs_transformed = bs_transformer.transform_documents(
            document, tags_to_extract=["p", "li", "div", "a"]
        )
        document_chunks = text_splitter.split_documents(docs_transformed)
        print(f"Indexing {len(document_chunks)} chunks from {url}")
        # Skip pages that yielded no text; Chroma rejects empty batches.
        # NOTE(review): no ids are passed, so re-indexing a URL presumably
        # stores its chunks again. Consider deterministic ids if that matters.
        if document_chunks:
            vector_store.add_documents(document_chunks)

    return vector_store.as_retriever()
def load_model():
    """Create the Hugging Face Hub client for Mixtral-8x7B-Instruct.

    This is the text-generation LLM that writes answers. It is not the
    embedding model used for retrieval.

    Returns:
        A ``HuggingFaceHub`` LLM configured for low-temperature generation of
        up to 512 new tokens.
    """
    generation_kwargs = dict(
        max_length=1048,
        temperature=0.1,
        max_new_tokens=512,
        top_p=0.95,
        repetition_penalty=1.0,
    )
    return HuggingFaceHub(
        repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
        model_kwargs=generation_kwargs,
    )
def get_conversational_rag_chain(retriever_chain):
    """Build a retrieval chain that answers as the call-center agent "Susie".

    Args:
        retriever_chain: Retriever used to fetch context documents for the
            user's input (e.g. the one returned by ``index_urls_in_rag``).

    Returns:
        The runnable from ``create_retrieval_chain``. Invoke it with
        ``{"input": ..., "chat_history": [...]}``; the reply is under the
        result's ``"answer"`` key.
    """
    llm = load_model()
    # System prompt (German): a friendly agent named Susie who answers only
    # from CONTEXT, uses a fixed hand-off phrase when CONTEXT lacks the answer,
    # and always replies in German. Grammar fixed so that the hand-off phrase
    # the model repeats to customers is correct German.
    system_prompt = (
        "Du bist eine freundliche Mitarbeiterin namens Susie und arbeitest in einem Call Center. "
        "Nutze immer und nur den CONTEXT für die Antwort auf folgende Frage. "
        "Wenn die Antwort nicht aus dem CONTEXT hervorgeht, antworte mit: "
        "'Ich bin mir nicht sicher. Wollen Sie einen Mitarbeiter sprechen?' "
        "Antworte bitte immer auf Deutsch. CONTEXT:\n\n{context}"
    )
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
    ])
    stuff_documents_chain = create_stuff_documents_chain(llm, prompt)
    return create_retrieval_chain(retriever_chain, stuff_documents_chain)
def get_response(message, history=None):
    """Answer *message* with the RAG chain built over the persisted vector store.

    Args:
        message: The user's question.
        history: Optional list of prior chat messages for the prompt's
            ``chat_history`` placeholder. ``None`` (the default) means no history.

    Returns:
        The answer text after the last " Assistant: " marker. The marker is
        stripped in case the model echoes the prompt.
    """
    # Avoid the shared-mutable-default pitfall of `history=[]`.
    history = [] if history is None else history
    # With no URLs, index_urls_in_rag just opens the persisted store.
    retriever = index_urls_in_rag()
    conversation_rag_chain = get_conversational_rag_chain(retriever)
    # The " Assistant: " suffix cues the (completion-style) LLM to start its answer.
    # NOTE(review): the suffix is also part of the retrieval query string.
    response = conversation_rag_chain.invoke({
        "chat_history": history,
        "input": message + " Assistant: ",
    })
    # Same result as split(" Assistant: ")[-1]: the text after the last marker,
    # or the whole answer if no marker is present.
    return response["answer"].rpartition(" Assistant: ")[2]
@app.on_event("startup")
async def startup():
    """FastAPI startup hook; currently it only logs a marker line.

    No indexing happens at startup. Content is indexed on demand through the
    ``/update/`` endpoint instead.
    """
    startup_message = "donee.... "
    print(startup_message)
@app.post("/generate/")
def generate(user_input):
    """POST /generate/: answer *user_input* with the RAG chain, with no chat history."""
    no_history = []
    answer = get_response(user_input, no_history)
    return answer
@app.post("/update/")
def update_index(index_url):
    """POST /update/: fetch *index_url* and add its content to the vector store.

    Previously this handler was also named ``generate``, which shadowed the
    /generate/ handler's module-level name. It also ran two hard-coded test
    queries, one of them a full LLM call, on every update and discarded their
    results. Both problems are removed here.

    Args:
        index_url: URL of the page to index (query parameter).

    Returns:
        A small JSON status object echoing the indexed URL.
    """
    index_urls_in_rag([index_url])
    return {"indexed_url": index_url}