# NOTE: the original paste began with a Hugging Face Spaces status banner
# ("Spaces: Runtime error") captured by the page scrape — it is not source code.
from collections import deque
from urllib.parse import urljoin, urlparse

import requests
from bs4 import BeautifulSoup
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.llms import HuggingFaceHub
from langchain_community.vectorstores import Chroma
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
# Application instance plus permissive CORS so browser front-ends on any
# origin can reach the API.
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Crawler entry point: gather every same-domain URL reachable from the seed.
def get_all_links_from_domain(domain_url):
    """Return the set of URLs found on the same domain as *domain_url*.

    Seeds :func:`get_links_from_page` with fresh visited/result sets and
    the seed URL's network location, then hands back what it collected.
    """
    seen = set()
    collected = set()
    seed_netloc = urlparse(domain_url).netloc
    get_links_from_page(domain_url, seen, collected, seed_netloc)
    return collected
# Crawl pages within one domain, accumulating every same-domain link found.
def get_links_from_page(url, visited_urls, all_links, base_domain):
    """Breadth-first crawl starting at *url*, restricted to *base_domain*.

    Mutates *visited_urls* (URLs fetched) and *all_links* (same-domain URLs
    discovered). *base_domain* is a bare netloc such as "www.bofrost.de".

    Fixes vs. the original:
    - The old guard ``url.startswith(base_domain)`` compared a full URL
      (beginning "https://...") against a bare netloc, so it always bailed
      out and nothing was ever crawled; we compare netlocs instead.
    - Recursion is replaced by an explicit queue — real sites easily
      exceeded Python's recursion limit.
    - ``requests.get`` now has a timeout and network errors are reported
      instead of crashing the whole crawl.
    """
    queue = deque([url])
    while queue:
        current = queue.popleft()
        # Stay on-domain and fetch each page at most once.
        if urlparse(current).netloc != base_domain or current in visited_urls:
            continue
        visited_urls.add(current)
        print("Getting next " + current)
        try:
            response = requests.get(current, timeout=10)
        except requests.RequestException as exc:
            print(f"Failed to retrieve content from {current}: {exc}")
            continue
        if response.status_code == 200:
            soup = BeautifulSoup(response.content, 'html.parser')
            parsed = urlparse(current)
            base_url = parsed.scheme + '://' + parsed.netloc
            for link in soup.find_all('a', href=True):
                # Resolve relative hrefs against the current page's origin.
                absolute_url = urljoin(base_url, link.get('href'))
                if urlparse(absolute_url).netloc == base_domain:
                    all_links.add(absolute_url)
                    if absolute_url not in visited_urls:
                        queue.append(absolute_url)
        else:
            print(f"Failed to retrieve content from {current}. Status code: {response.status_code}")
# Build the hosted LLM client used as the generator side of the RAG pipeline.
def load_rag_model():
    """Return a HuggingFaceHub client for Mixtral-8x7B-Instruct-v0.1."""
    generation_kwargs = {
        "max_length": 1048,
        "temperature": 0.2,
        "max_new_tokens": 256,
        "top_p": 0.95,
        "repetition_penalty": 1.0,
    }
    return HuggingFaceHub(
        repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
        model_kwargs=generation_kwargs,
    )
# Function to index URLs in RAG
def index_urls_in_rag(urls):
    """Embed the pages behind *urls* into a Chroma store and return a
    history-aware retriever chain over them.

    For each URL: load the page, split it into chunks, and add the chunks
    to a persistent vector store; then wrap the store's retriever in a
    chain that reformulates queries using chat history.
    """
    # NOTE(review): Chroma is constructed without an explicit embedding
    # function, so it uses its default — confirm this matches whatever
    # embeddings are expected at query time.
    vector_store = Chroma(persist_directory="/home/user/.cache/chroma_db")
    # Load the RAG model
    rag_model = load_rag_model()
    # The splitter is stateless and loop-invariant; the original rebuilt
    # it once per URL, so hoist it out of the loop.
    text_splitter = RecursiveCharacterTextSplitter()
    for url in urls:
        # Get text from the URL
        loader = WebBaseLoader(url)
        document = loader.load()
        # Split the document into chunks
        document_chunks = text_splitter.split_documents(document)
        # Index document chunks into the vector store
        vector_store.add_documents(document_chunks)
    # Convert vector store to retriever
    retriever = vector_store.as_retriever()
    # Prompt used to condense chat history + new input into a search query.
    prompt = ChatPromptTemplate.from_messages([
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
    ])
    # Create history-aware retriever chain
    retriever_chain = create_history_aware_retriever(rag_model, retriever, prompt)
    return retriever_chain
# Index URLs on app startup
@app.on_event("startup")
async def startup():
    """Crawl the FAQ domain once at boot and build the retriever chain.

    Fixes vs. the original: the chain was assigned to a *local* variable
    and discarded (so ``generate`` hit a NameError on the module global),
    and the function was never registered with the app despite the
    comment — hence the ``global`` statement and the ``on_event`` hook.
    """
    global retriever_chain
    domain_url = 'https://www.bofrost.de/faq/'
    urls = get_all_links_from_domain(domain_url)
    retriever_chain = index_urls_in_rag(urls)
# Define API endpoint to receive queries and provide responses
# (the original had no route decorator, so it was never reachable over HTTP).
@app.post("/generate")
def generate(user_input):
    """Answer *user_input* using the retriever chain built at startup.

    Raises RuntimeError when the index has not been built yet (startup
    not run or still crawling) instead of a confusing NameError.
    """
    # Look the chain up defensively: the module global only exists after
    # a successful startup.
    chain = globals().get("retriever_chain")
    if chain is None:
        raise RuntimeError("RAG index not initialised; startup has not completed")
    response = chain.invoke({
        "chat_history": [],
        "input": user_input
    })
    return response['answer']