# app.py — FastAPI service that crawls a domain, indexes its pages into a
# Chroma vector store, and answers queries via a LangChain RAG chain.
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from langchain_community.llms import HuggingFaceHub
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from urllib.parse import urlparse, urljoin
import requests
from bs4 import BeautifulSoup
app = FastAPI()

# Allow cross-origin requests from any origin so browser front-ends can
# call this API directly.
_ALLOW_ALL = ['*']
app.add_middleware(
    CORSMiddleware,
    allow_origins=_ALLOW_ALL,
    allow_credentials=True,
    allow_methods=_ALLOW_ALL,
    allow_headers=_ALLOW_ALL,
)
# Function to crawl all URLs from a domain
def get_all_links_from_domain(domain_url):
    """Crawl starting at ``domain_url`` and return the set of discovered
    URLs whose host matches the starting URL's host."""
    seen = set()
    found = set()
    host = urlparse(domain_url).netloc
    get_links_from_page(domain_url, seen, found, host)
    return found
# Function to crawl links from a page within the same domain
def get_links_from_page(url, visited_urls, all_links, base_domain):
    """Crawl ``url`` and every same-domain page reachable from it.

    Mutates the two sets supplied by the caller:
      visited_urls -- every URL already fetched (or attempted), so pages
                      are never downloaded twice
      all_links    -- every discovered URL whose netloc equals ``base_domain``

    Fixes vs. the original:
      * iterative (explicit stack) instead of recursive, so a large site
        cannot exhaust the interpreter's recursion limit
      * ``requests.get`` has a timeout, so one slow host cannot hang the crawl
      * network errors are caught per-URL instead of aborting the whole crawl
    """
    stack = [url]
    while stack:
        current = stack.pop()
        if current in visited_urls:
            continue
        visited_urls.add(current)
        print("Getting next " + current)
        try:
            # Timeout keeps a single unresponsive host from stalling everything.
            response = requests.get(current, timeout=10)
        except requests.RequestException as exc:
            print(f"Failed to retrieve content from {current}. Error: {exc}")
            continue
        if response.status_code != 200:
            print(f"Failed to retrieve content from {current}. Status code: {response.status_code}")
            continue
        soup = BeautifulSoup(response.content, 'html.parser')
        parsed = urlparse(current)
        base_url = parsed.scheme + '://' + parsed.netloc
        for link in soup.find_all('a', href=True):
            absolute_url = urljoin(base_url, link.get('href'))
            # Only follow links that stay on the starting domain.
            if urlparse(absolute_url).netloc == base_domain:
                all_links.add(absolute_url)
                if absolute_url not in visited_urls:
                    stack.append(absolute_url)
# Function to load the RAG model
def load_rag_model():
    """Return a HuggingFaceHub LLM handle for Mixtral-8x7B-Instruct with
    fixed generation settings."""
    generation_kwargs = {
        "max_length": 1048,
        "temperature": 0.2,
        "max_new_tokens": 256,
        "top_p": 0.95,
        "repetition_penalty": 1.0,
    }
    return HuggingFaceHub(
        repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
        model_kwargs=generation_kwargs,
    )
# Function to index URLs in RAG
def index_urls_in_rag(urls):
    """Index every URL into the Chroma vector store and return a full
    retrieval chain.

    The returned chain's ``invoke({"chat_history": ..., "input": ...})``
    yields a dict containing an ``"answer"`` key, which is what the
    ``/generate/`` endpoint reads.  (Previously only the history-aware
    *retriever* was returned; its output is a list of documents with no
    ``"answer"`` key, so the endpoint raised on every request.  The
    already-imported ``create_stuff_documents_chain`` and
    ``create_retrieval_chain`` are now used to build the intended chain.)
    """
    # Vector store holding document embeddings, persisted on local disk.
    vector_store = Chroma(persist_directory="/home/user/.cache/chroma_db")
    # Load the RAG model
    rag_model = load_rag_model()
    # One splitter instance suffices — hoisted out of the loop.
    text_splitter = RecursiveCharacterTextSplitter()
    for url in urls:
        # Fetch the page text and split it into indexable chunks.
        document = WebBaseLoader(url).load()
        document_chunks = text_splitter.split_documents(document)
        vector_store.add_documents(document_chunks)
    # Convert vector store to retriever
    retriever = vector_store.as_retriever()
    # Prompt used to turn the conversation into a standalone retrieval query.
    retriever_prompt = ChatPromptTemplate.from_messages([
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
    ])
    history_aware_retriever = create_history_aware_retriever(rag_model, retriever, retriever_prompt)
    # Prompt used to answer the question from the retrieved context documents.
    answer_prompt = ChatPromptTemplate.from_messages([
        ("system", "Answer the user's questions based on the below context:\n\n{context}"),
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
    ])
    document_chain = create_stuff_documents_chain(rag_model, answer_prompt)
    # Full RAG chain: retrieve relevant chunks, then generate an answer.
    return create_retrieval_chain(history_aware_retriever, document_chain)
# Index URLs on app startup
@app.on_event("startup")
async def startup():
    """Crawl the FAQ domain and build the retrieval chain once at startup.

    Bug fix: the chain must be stored in the module-level ``retriever_chain``
    name that the ``/generate/`` endpoint reads.  Previously it was assigned
    to a function-local variable, so the endpoint raised ``NameError`` on
    its first request.
    """
    global retriever_chain  # read by the /generate/ endpoint
    domain_url = 'https://www.bofrost.de/faq/'
    urls = get_all_links_from_domain(domain_url)
    retriever_chain = index_urls_in_rag(urls)
# Define API endpoint to receive queries and provide responses
@app.post("/generate/")
def generate(user_input):
    """Answer ``user_input`` with the retrieval chain built at startup.

    NOTE(review): relies on the module-level ``retriever_chain`` being set
    by the startup handler before the first request arrives.
    """
    result = retriever_chain.invoke({
        "chat_history": [],
        "input": user_input,
    })
    return result['answer']