Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from langchain_community.llms import HuggingFaceHub | |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langchain.chains import create_history_aware_retriever, create_retrieval_chain | |
| from langchain.chains.combine_documents import create_stuff_documents_chain | |
| from langchain_community.document_loaders import WebBaseLoader | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.vectorstores import Chroma | |
| from urllib.parse import urlparse, urljoin | |
| import requests | |
| from bs4 import BeautifulSoup | |
app = FastAPI()

# Let browser clients on any origin call this API with any HTTP method and
# header, and with credentials.
# NOTE(review): wildcard origins combined with allow_credentials=True means any
# site may send credentialed requests here — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
# Function to crawl all URLs from a domain
def get_all_links_from_domain(domain_url):
    """Return the set of same-host URLs found by crawling from *domain_url*.

    The host (``netloc``) of ``domain_url`` bounds the crawl. The starting URL
    itself is only part of the result if some crawled page links to it.
    """
    seen_pages, found_links = set(), set()
    get_links_from_page(domain_url, seen_pages, found_links, urlparse(domain_url).netloc)
    return found_links
# Function to crawl links from a page within the same domain
def get_links_from_page(url, visited_urls, all_links, base_domain, *, timeout=10, max_pages=None):
    """Crawl same-host links reachable from *url*, breadth-first.

    Each fetched page's ``<a href>`` targets are resolved against that page's
    URL. Targets whose host equals ``base_domain`` are added to ``all_links``
    and queued for crawling. Both sets are updated in place.

    This is an iterative crawl, so large sites no longer hit Python's
    recursion limit.

    Args:
        url: Page to start crawling from.
        visited_urls: URLs already fetched (or attempted); updated in place.
        all_links: Collected same-host links; updated in place.
        base_domain: ``netloc`` the crawl must stay on.
        timeout: Seconds to wait for each HTTP response.
        max_pages: Optional cap on pages fetched by this call; ``None`` means
            no limit.
    """
    from collections import deque
    from urllib.parse import urldefrag

    queue = deque([url])
    fetched = 0
    while queue:
        current = queue.popleft()
        if current in visited_urls:
            continue
        if max_pages is not None and fetched >= max_pages:
            break
        visited_urls.add(current)
        fetched += 1
        print("Getting next " + current)
        try:
            # Without a timeout a single unresponsive server can hang the crawl.
            response = requests.get(current, timeout=timeout)
        except requests.RequestException as exc:
            # One unreachable page should not abort the whole crawl.
            print(f"Failed to retrieve content from {current}: {exc}")
            continue
        if response.status_code != 200:
            print(f"Failed to retrieve content from {current}. Status code: {response.status_code}")
            continue
        # Only HTML pages carry links worth parsing (skip PDFs, images, ...).
        if "html" not in response.headers.get("Content-Type", "").lower():
            continue
        soup = BeautifulSoup(response.content, 'html.parser')
        for link in soup.find_all('a', href=True):
            # Resolve against the final page URL (after redirects), not just
            # scheme://host, so relative hrefs keep the page's directory.
            # Drop "#fragment" so in-page anchors are not crawled as new pages.
            absolute_url, _fragment = urldefrag(urljoin(response.url, link.get('href')))
            if urlparse(absolute_url).netloc == base_domain:
                all_links.add(absolute_url)
                if absolute_url not in visited_urls:
                    queue.append(absolute_url)
# Function to load the RAG model
def load_rag_model():
    """Build the Hugging Face Hub LLM client used for generation.

    Targets ``mistralai/Mixtral-8x7B-Instruct-v0.1`` with low-temperature
    sampling (0.2, top-p 0.95) and at most 256 new tokens.
    Presumably credentials are picked up from the environment — confirm.
    """
    # NOTE(review): both max_length and max_new_tokens are sent; confirm which
    # one the hosted endpoint actually honors.
    generation_settings = {
        "max_length": 1048,
        "temperature": 0.2,
        "max_new_tokens": 256,
        "top_p": 0.95,
        "repetition_penalty": 1.0,
    }
    return HuggingFaceHub(
        repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
        model_kwargs=generation_settings,
    )
# Function to index URLs in RAG
def index_urls_in_rag(urls):
    """Index the pages at *urls* into Chroma and build a conversational RAG chain.

    Each URL is loaded with ``WebBaseLoader``, split into chunks, and added to a
    Chroma store persisted under ``/home/user/.cache/chroma_db``. URLs that fail
    to load are reported and skipped.

    Args:
        urls: Iterable of page URLs to index.

    Returns:
        A chain built by ``create_retrieval_chain``. It is invoked with
        ``{"chat_history": [...], "input": str}`` and returns a dict whose
        ``"answer"`` key holds the model's reply, which is what ``generate``
        reads. Previously only the history-aware retriever was returned; that
        yields documents, so ``response['answer']`` failed.
    """
    # NOTE(review): no embedding_function is passed; presumably Chroma then
    # falls back to its own default embedder — confirm this is intended.
    vector_store = Chroma(persist_directory="/home/user/.cache/chroma_db")
    rag_model = load_rag_model()

    # One splitter serves every page; no need to rebuild it per URL.
    text_splitter = RecursiveCharacterTextSplitter()
    for url in urls:
        try:
            document = WebBaseLoader(url).load()
        except requests.RequestException as exc:
            # A single unreachable page should not abort indexing (and startup).
            print(f"Skipping {url}: {exc}")
            continue
        document_chunks = text_splitter.split_documents(document)
        # Pages with no extractable text produce no chunks; nothing to store.
        if document_chunks:
            vector_store.add_documents(document_chunks)

    retriever = vector_store.as_retriever()

    # Prompt that turns the conversation plus the latest user turn into a
    # search query for the retriever.
    rephrase_prompt = ChatPromptTemplate.from_messages([
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
        ("user", "Given the above conversation, generate a search query to look up in order to get information relevant to the conversation"),
    ])
    retriever_chain = create_history_aware_retriever(rag_model, retriever, rephrase_prompt)

    # Prompt that answers the question from the retrieved documents, which
    # create_stuff_documents_chain injects as {context}.
    answer_prompt = ChatPromptTemplate.from_messages([
        ("system", "Answer the user's questions based on the below context:\n\n{context}"),
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
    ])
    document_chain = create_stuff_documents_chain(rag_model, answer_prompt)

    return create_retrieval_chain(retriever_chain, document_chain)
# Index URLs on app startup
@app.on_event("startup")
async def startup():
    """Crawl the FAQ site and build the RAG chain when the server starts.

    The handler is registered with FastAPI; previously it was never hooked up,
    so it never ran. The chain is published as the module-level
    ``retriever_chain`` read by ``generate``; previously it was assigned to a
    local and discarded.
    """
    global retriever_chain
    domain_url = 'https://www.bofrost.de/faq/'
    urls = get_all_links_from_domain(domain_url)
    retriever_chain = index_urls_in_rag(urls)
# Define API endpoint to receive queries and provide responses
@app.get("/generate")
def generate(user_input: str):
    """Answer *user_input* using the RAG chain built at startup.

    Exposed as ``GET /generate?user_input=...``. Previously no route was
    registered, so the app served nothing. Each call is stateless: the chain
    is invoked with an empty chat history.

    Returns:
        The ``"answer"`` entry of the chain's output.

    Raises:
        HTTPException: 503 if no chain is available yet.
    """
    from fastapi import HTTPException

    # The chain is expected to be published as a module global by the startup
    # hook. Look it up defensively so a request that arrives before indexing
    # finishes (or if indexing never ran) gets a clear 503 instead of NameError.
    chain = globals().get("retriever_chain")
    if chain is None:
        raise HTTPException(status_code=503, detail="RAG index is not ready yet")
    response = chain.invoke({
        "chat_history": [],
        "input": user_input,
    })
    return response['answer']