Chris4K's picture
Update app.py
f44a6d7 verified
raw
history blame
5.07 kB
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from langchain_community.llms import HuggingFaceHub
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from urllib.parse import urlparse, urljoin
import requests
from bs4 import BeautifulSoup
app = FastAPI()

# Cross-origin settings: accept requests from any origin, using any HTTP method
# and any header, with credentials allowed.
# NOTE(review): FastAPI's CORS docs warn that wildcard origins and
# allow_credentials=True do not combine the way one might expect; confirm this
# is intended, or list explicit origins if cookies/auth are ever used.
_CORS_SETTINGS = dict(
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# Module-level slot for the retriever chain; starts out unset (None).
retriever_chain = None
def get_all_links_from_domain(domain_url):
    """Crawl outward from ``domain_url`` and return the URLs discovered.

    ``domain_url`` serves as both the starting page and the ``base_domain``
    scope argument passed to ``get_links_from_page``. The seed URL is always
    part of the result, so its own content is indexed even when no crawled
    page links back to it.

    Args:
        domain_url: Absolute URL to start crawling from.

    Returns:
        A set of absolute URL strings that always includes ``domain_url``.
    """
    print("domain url " + domain_url)
    visited_urls = set()
    # Seed the result with the start page itself; the crawler only records
    # URLs it finds as links on fetched pages.
    domain_links = {domain_url}
    get_links_from_page(domain_url, visited_urls, domain_links, domain_url)
    return domain_links
def _is_in_scope(candidate, base_domain):
    """Return True when ``candidate`` lies inside the crawl scope ``base_domain``.

    ``base_domain`` may be either a URL prefix such as
    ``"https://www.example.com/faq/"``, in which case the candidate must start
    with it, or a bare host name such as ``"www.example.com"``, in which case
    the candidate's host must equal it.
    """
    if "://" in base_domain:
        return candidate.startswith(base_domain)
    return urlparse(candidate).netloc == base_domain


# Function to crawl links from a page within the same domain
def get_links_from_page(url, visited_urls, all_links, base_domain, timeout=10):
    """Crawl ``url`` and every in-scope page reachable from it.

    Each fetched page's ``<a href>`` targets are resolved against that page's
    URL and stripped of any ``#fragment``. Targets inside ``base_domain`` (see
    ``_is_in_scope``) are added to ``all_links`` and queued for crawling. An
    explicit stack replaces recursion, so a deep site cannot exhaust Python's
    recursion limit.

    Args:
        url: Absolute URL to start from.
        visited_urls: Set of URLs already fetched; updated in place.
        all_links: Set collecting discovered in-scope URLs; updated in place.
        base_domain: Crawl scope, either a URL prefix or a bare host name.
        timeout: Seconds to wait for each HTTP response.

    Returns:
        None. Results accumulate in ``visited_urls`` and ``all_links``.
    """
    from urllib.parse import urldefrag

    pending = [url]
    while pending:
        current = pending.pop()
        print("url " + current)
        print("base_domain " + base_domain)
        if not _is_in_scope(current, base_domain) or current in visited_urls:
            continue
        visited_urls.add(current)
        print("Getting next" + current)
        try:
            # Without a timeout, a single unresponsive server stalls the crawl forever.
            response = requests.get(current, timeout=timeout)
        except requests.RequestException as exc:
            print(f"Failed to retrieve content from {current}. Error: {exc}")
            continue
        if response.status_code != 200:
            print(f"Failed to retrieve content from {current}. Status code: {response.status_code}")
            continue
        soup = BeautifulSoup(response.content, 'html.parser')
        discovered = []
        for link in soup.find_all('a', href=True):
            # Resolve against the full page URL, not just scheme://host, so a
            # relative link like "page.html" keeps the page's directory. Drop
            # the fragment so "#section" anchors don't count as new pages.
            absolute_url, _fragment = urldefrag(urljoin(current, link.get('href')))
            if _is_in_scope(absolute_url, base_domain):
                all_links.add(absolute_url)
                discovered.append(absolute_url)
        # Push in reverse so pages are popped in document order (depth-first).
        pending.extend(reversed(discovered))
def load_rag_model():
    """Construct the Hugging Face Hub LLM used by the retrieval chain.

    Returns:
        A ``HuggingFaceHub`` wrapper for ``mistralai/Mixtral-8x7B-Instruct-v0.1``
        with the generation settings below.
    """
    # NOTE(review): both max_length (1048, possibly meant 1024) and
    # max_new_tokens are set; confirm which one the endpoint actually honors.
    generation_settings = {
        "max_length": 1048,
        "temperature": 0.2,
        "max_new_tokens": 256,
        "top_p": 0.95,
        "repetition_penalty": 1.0,
    }
    return HuggingFaceHub(
        repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
        model_kwargs=generation_settings,
    )
def index_urls_in_rag(urls):
    """Load each URL into a Chroma vector store and build a history-aware retriever.

    Args:
        urls: Iterable of page URLs to fetch with ``WebBaseLoader``.

    Returns:
        The chain produced by ``create_history_aware_retriever`` over a
        retriever backed by the populated store.
    """
    # Persistent store. No embedding function is passed here.
    # NOTE(review): confirm which default embedding Chroma applies in that case.
    vector_store = Chroma(persist_directory="/home/user/.cache/chroma_db")
    rag_model = load_rag_model()
    # Default-configured splitter; build it once instead of once per URL.
    text_splitter = RecursiveCharacterTextSplitter()
    for url in urls:
        try:
            documents = WebBaseLoader(url).load()
        except requests.RequestException as exc:
            # One unreachable page should not abort indexing of all the others.
            print(f"Failed to load {url}: {exc}")
            continue
        document_chunks = text_splitter.split_documents(documents)
        # A page with no extractable text yields no chunks; skip it rather than
        # handing the vector store an empty batch.
        if document_chunks:
            vector_store.add_documents(document_chunks)
    retriever = vector_store.as_retriever()
    # NOTE(review): the prompt carries no instruction for turning the
    # conversation into a search query; only history and the raw input are sent.
    prompt = ChatPromptTemplate.from_messages([
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
    ])
    return create_history_aware_retriever(rag_model, retriever, prompt)
@app.on_event("startup")
async def startup():
    """Crawl the FAQ site at startup and store the retriever chain globally.

    The ``global`` declaration is required. Without it, the assignment binds
    a local name and the module-level ``retriever_chain`` stays ``None``.
    """
    global retriever_chain
    domain_url = 'https://www.bofrost.de/faq/'
    # NOTE(review): crawling and indexing are synchronous calls inside an
    # async hook, so they block the event loop until startup finishes.
    urls = get_all_links_from_domain(domain_url)
    retriever_chain = index_urls_in_rag(urls)
@app.post("/generate/")
def generate(user_input: str):
    """Answer ``user_input`` via ``get_response`` with an empty chat history.

    ``user_input`` is a plain (non-``Body``) parameter, so under FastAPI's
    parameter rules it is read from the query string
    (``POST /generate/?user_input=...``).
    """
    # Fixed: the original call had a stray ']' and was a SyntaxError.
    return get_response(user_input, [])
def get_response(message, history=None):
    """Answer ``message`` with a conversational RAG chain and return the reply text.

    Args:
        message: The user's question.
        history: Prior conversation turns from the caller. Accepted for
            interface compatibility but currently not forwarded to the chain
            (see the note below).

    Returns:
        The text of the chain's ``answer`` after the last ``" Assistant: "``
        marker, or the whole answer if the marker is absent.
    """
    # NOTE(review): get_vectorstore_from_url, get_context_retriever_chain and
    # get_conversational_rag_chain are not defined in this file; they must be
    # provided elsewhere, or this raises NameError.
    # NOTE(review): the store is rebuilt from a fixed URL on every request, and
    # the module-level retriever_chain built at startup is not used. Confirm intent.
    vs = get_vectorstore_from_url("https://huggingface.co/Chris4K")
    context_chain = get_context_retriever_chain(vs)
    conversation_rag_chain = get_conversational_rag_chain(context_chain)
    # NOTE(review): the caller's history is not forwarded; its format is not
    # established here, so the chain always sees an empty chat history.
    response = conversation_rag_chain.invoke({
        "chat_history": [],
        "input": message + " Assistant: ",
        "chat_message": message + " Assistant: ",
    })
    # The model may echo the prompt; keep only the text after the last marker.
    return response['answer'].split(" Assistant: ")[-1]