Commit 9f84bcd · Parent(s): 01d3dd8

Backend added

Files changed:
- .gitignore       +9   -0
- Dockerfile       +50  -0
- README.md        +92  -0
- chatbot.py       +178 -0
- main.py          +190 -0
- requirements.txt +158 -0
- worker.py        +154 -0
.gitignore
ADDED
@@ -0,0 +1,9 @@
+venv
+cache
+model
+__pycache__
+venv
+t.py
+temp.py
+.env
+app.py
Dockerfile
ADDED
@@ -0,0 +1,50 @@
+# Use a Python base image compatible with Hugging Face Spaces
+FROM python:3.11-slim
+
+# Set the working directory inside the container
+WORKDIR /app
+
+# Prevent Python from writing pyc files and buffering stdout
+ENV PYTHONDONTWRITEBYTECODE=1
+ENV PYTHONUNBUFFERED=1
+
+# Install system dependencies required for Playwright + Crawl4AI
+RUN apt-get update && apt-get install -y \
+    curl \
+    wget \
+    unzip \
+    git \
+    xvfb \
+    libnss3 \
+    libatk-bridge2.0-0 \
+    libx11-xcb1 \
+    libxcomposite1 \
+    libxdamage1 \
+    libxrandr2 \
+    libgbm-dev \
+    libasound2 \
+    libatk1.0-0 \
+    libxkbcommon0 \
+    libcups2 \
+    libgtk-3-0 \
+    fonts-liberation \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy requirements and install Python dependencies
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Install Playwright Chromium (used by Crawl4AI)
+RUN playwright install --with-deps chromium
+
+# Copy the entire app (including .env)
+COPY . .
+
+# Expose the port expected by Hugging Face (7860)
+EXPOSE 7860
+
+# Hugging Face expects the app to listen on port 7860
+ENV PORT=7860
+
+# Command to run FastAPI with uvicorn
+CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
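
Before pushing to the Space, the container can be exercised locally. The snippet below is a minimal smoke test and not part of the commit; the image tag `webiq-backend` and the `7860:7860` port mapping are assumptions.

```python
# Hypothetical local smoke test for the container built from the Dockerfile above.
# Assumed setup (not part of the commit):
#   docker build -t webiq-backend .
#   docker run --rm -p 7860:7860 webiq-backend
import requests

def check_root() -> None:
    # main.py defines GET "/" returning a readiness message.
    resp = requests.get("http://127.0.0.1:7860/", timeout=10)
    resp.raise_for_status()
    print(resp.json())  # expected: {"message": "...", "status": "Ready"}

if __name__ == "__main__":
    check_root()
```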
README.md
CHANGED
@@ -8,3 +8,95 @@ pinned: false
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+# WebIQ – Boost your web intelligence with AI-powered insights
+
+## Overview
+WebIQ is a **web scraping** and **question-answering (QA)** chatbot built on the **Retrieval-Augmented Generation (RAG)** pipeline. It extracts key insights from any website and generates **AI-powered** responses grounded in the extracted data. WebIQ uses **FAISS** for efficient similarity search, **LangChain** for retrieval orchestration, and state-of-the-art **LLMs** for response generation.
+
+## Features
+- **Automated Web Scraping**: Extracts text from webpages, caches it locally, and supports both targeted and full-site scraping.
+- **Vector Embeddings**: Uses FAISS to store and retrieve information efficiently.
+- **LLM Integration**: Supports OpenAI (GPT-4) and Hugging Face models (Llama-2, Mistral, etc.).
+- **Chunking for Optimization**: Splits documents into meaningful chunks to improve retrieval quality.
+- **Asynchronous Processing**: Uses `asyncio` for efficient execution.
+- **Caching Mechanism**: Ensures previously processed webpages are not reprocessed.
+- **Batch Processing**: Processes large numbers of URLs efficiently.
+- **Memory Usage Logging**: Tracks memory consumption before and after each batch for efficiency monitoring.
+- **Multi-Page Scraping**: Scrapes content from multiple webpages and aggregates insights.
+
+## Installation
+
+1. Clone the repository:
+   ```sh
+   git clone https://github.com/Siddharth-Chandel/WebIQ.git
+   cd WebIQ
+   ```
+
+2. Create a virtual environment and activate it:
+   ```sh
+   python -m venv venv
+   source venv/bin/activate  # On Windows use `venv\Scripts\activate`
+   ```
+
+3. Install the required dependencies:
+   ```sh
+   pip install -r requirements.txt
+   ```
+
+4. Set up environment variables by creating a `.env` file:
+   ```sh
+   HUGGINGFACEHUB_API_TOKEN=your_huggingface_token
+   OPENAI_API_KEY=your_openai_api_key  # If using OpenAI
+   ```
+
+## Usage
+
+1. Run the chatbot script:
+   ```sh
+   python chatbot.py
+   ```
+2. Enter a **URL** when prompted (e.g., `https://playwright.dev`).
+3. Enter your **query** (e.g., `Describe Playwright and its benefits`).
+4. The chatbot will scrape the webpage, process the data, and return an AI-generated response.
+
+## Example Output
+```
+====================* Answer *====================
+Playwright is an end-to-end testing framework that provides...
+
+=================* Source Documents *=================
+Source 1:
+file: cache/playwright-dev/pages/page_1.txt
+Content: Playwright is a Node.js library that automates browsers.
+```
+
+## Practical Use Cases
+- **Research Assistance**: Quickly extract and summarize information from research papers, blogs, or documentation.
+- **Competitive Analysis**: Monitor competitors' websites and extract relevant insights for business strategy.
+- **Customer Support**: Enhance chatbot capabilities by integrating real-time website data retrieval.
+- **Market Intelligence**: Gather structured data from news sites, product pages, or financial reports for analysis.
+- **SEO Optimization**: Analyze webpage content for better keyword targeting and content strategy.
+
+## Technologies Used
+- **RAG** (Provides better context to the LLM)
+- **LangChain** (Retrieval-based QA orchestration)
+- **FAISS** (Efficient similarity search)
+- **Hugging Face Transformers** (LLMs & embeddings)
+- **OpenAI GPT-4** (Optional, for LLM-based response generation)
+- **Crawl4AI** (LLM-friendly web scraper)
+- **AsyncIO** (Speeds up processing)
+- **Rich** (Colorful CLI output)
+
+## Future Enhancements
+- Develop an interactive web UI using Streamlit or FastAPI for a seamless user experience.
+- Enhance retrieval quality with advanced RAG tuning and improved embeddings.
+
+## License
+This project is licensed under the **MIT License**.
+
+## Author
+Siddharth Chandel - Developed as part of NLP & AI research.
+Let's connect on [LinkedIn](https://www.linkedin.com/in/siddharth-chandel-001097245/)!
+
+---
+_Contributions are welcome! Feel free to fork and enhance._ 🚀
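
Beyond the CLI flow described in the README, the `Chatbot` class added in chatbot.py (next file in this commit) can be driven programmatically. A minimal sketch, mirroring the file's own example runner; the URL and question are illustrative:

```python
import asyncio
from chatbot import Chatbot

async def run() -> None:
    # A one-element list uses the "cache/list_<site>/pages" cache layout.
    bot = Chatbot(["https://playwright.dev"])
    await bot.initialize()  # scrape (if not cached) and build/load the FAISS index
    answer = await bot.query("Describe Playwright and its benefits")
    print(answer)

if __name__ == "__main__":
    asyncio.run(run())
```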
chatbot.py
ADDED
@@ -0,0 +1,178 @@
+from dotenv import load_dotenv
+
+load_dotenv()
+import os
+import asyncio
+import logging
+from langchain_community.document_loaders import TextLoader
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_huggingface import HuggingFaceEmbeddings
+from langchain_community.vectorstores import FAISS
+from langchain_openai import ChatOpenAI
+from langchain_community.llms import CTransformers
+from langchain_core.prompts import PromptTemplate
+from transformers import pipeline
+from langchain_huggingface import HuggingFacePipeline
+from rich import print as rprint
+from worker import scrape_website
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+
+HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
+os.environ["HUGGINGFACEHUB_API_TOKEN"] = os.getenv("HUGGINGFACEHUB_API_TOKEN")
+
+DEFAULT_MODEL = "TheBloke/Llama-2-7B-Chat-GGML"
+EMBEDDING_MODEL = "BAAI/bge-small-en"
+
+# -------------------- Document Preparation --------------------
+async def prepare_document(url: str | list[str]):
+    if isinstance(url, str):
+        folder = f"{url[8:].replace('.', '-').split('/')[0]}"
+        cache_path = os.path.join("cache", folder, "pages")
+    else:
+        folder = f"{url[0][8:].replace('.', '-').split('/')[0]}"
+        cache_path = os.path.join("cache", f"list_{folder}", "pages")
+
+    os.makedirs(cache_path, exist_ok=True)
+
+    if not os.path.exists(f"{cache_path}/page_1.txt"):
+        logging.info("Document not found. Scraping website...")
+        await scrape_website(url, cache_path)
+        logging.info("Scraping completed.")
+
+    return cache_path
+
+# -------------------- Embedding --------------------
+def get_embedding_model(embedding_model_name="", api_key=""):
+    # Use OpenAI if api_key provided or model name indicates OpenAI
+    if api_key or "openai" in embedding_model_name.lower():
+        if not api_key:
+            raise ValueError("OpenAI API key required for OpenAI embeddings")
+        from langchain_openai import OpenAIEmbeddings
+        return OpenAIEmbeddings(model="text-embedding-3-small", api_key=api_key)
+
+    # Use HuggingFace otherwise
+    else:
+        # Ensure HF token is set in env for this thread
+        HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
+        os.environ["HUGGINGFACEHUB_API_TOKEN"] = HUGGINGFACEHUB_API_TOKEN
+
+        return HuggingFaceEmbeddings(model_name=embedding_model_name or EMBEDDING_MODEL)
+
+
+# -------------------- Process & Build Vector Store --------------------
+def process_documents(file_path: str, embedding_model, chunk_size=500, chunk_overlap=100):
+    try:
+        cache_path = os.path.dirname(file_path)
+        faiss_path = f"{cache_path}/faiss_index_store"
+
+        if os.path.exists(faiss_path):
+            logging.info("FAISS index exists. Skipping rebuild.")
+            return
+
+        documents = []
+        for file in os.listdir(f"{cache_path}/pages"):
+            doc_loader = TextLoader(os.path.join(cache_path, "pages", file), encoding="utf-8")
+            documents.extend(doc_loader.load())
+
+        logging.info(f"Loaded {len(documents)} pages")
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+        chunks = text_splitter.split_documents(documents)
+
+        vector_db = FAISS.from_documents(chunks, embedding_model)
+        vector_db.save_local(faiss_path)
+        logging.info("FAISS store saved successfully")
+
+    except Exception as e:
+        logging.error(f"Error in document processing: {e}")
+
+# -------------------- Load Retriever --------------------
+async def load_retriever(file_path: str, embedding_model_name="", api_key=""):
+    cache_path = os.path.dirname(file_path)
+    embedding_model = get_embedding_model(embedding_model_name, api_key)
+    faiss_path = f"{cache_path}/faiss_index_store"
+
+    if not os.path.exists(faiss_path):
+        logging.warning("FAISS index missing. Rebuilding...")
+        process_documents(file_path, embedding_model)
+
+    vector_db = FAISS.load_local(faiss_path, embedding_model, allow_dangerous_deserialization=True)
+    return vector_db.as_retriever(search_kwargs={"k": 3})
+
+# -------------------- Build Custom QA Pipeline --------------------
+async def build_pipeline(url: str | list, llm_model="", embedding_model="", api_key=""):
+    # Force default model if llm_model is empty or 'default'
+    if not llm_model or llm_model.lower() == "default":
+        llm_model = DEFAULT_MODEL
+    logging.info(f"[LLM] Using model: {llm_model}")
+
+    file_path = await prepare_document(url)
+    retriever = await load_retriever(file_path, embedding_model, api_key)
+
+    llm_model_lower = llm_model.lower()
+    # OpenAI LLM
+    if "openai" in llm_model_lower:
+        llm = ChatOpenAI(model_name="gpt-3.5-turbo", openai_api_key=api_key)
+    # GGML model
+    elif llm_model_lower.endswith("-ggml"):
+        llm = CTransformers(model=llm_model, model_type="llama", config={"context_length": 4096})
+    # Hugging Face PyTorch model
+    else:
+        try:
+            hf_pipeline = pipeline(
+                "text-generation",
+                model=llm_model,
+                use_auth_token=HUGGINGFACEHUB_API_TOKEN
+            )
+            llm = HuggingFacePipeline(pipeline=hf_pipeline)
+        except Exception as e:
+            logging.error(f"Failed to load Hugging Face model '{llm_model}'. Error: {e}")
+            raise RuntimeError(f"Cannot load Hugging Face model: {e}")
+
+    prompt = PromptTemplate(
+        input_variables=["context", "question"],
+        template="You are a helpful assistant. Use the following context to answer.\n\nContext:\n{context}\n\nQuestion: {question}\n\nAnswer:"
+    )
+
+    return llm, retriever, prompt
+
+
+class Chatbot:
+    def __init__(self, url: str | list, llm_model="", embedding_model="", api_key=""):
+        self.url = url
+        self.llm_model = llm_model
+        self.embedding_model = embedding_model
+        self.api_key = api_key
+
+    async def initialize(self):
+        self.llm, self.retriever, self.prompt = await build_pipeline(
+            self.url, self.llm_model, self.embedding_model, self.api_key
+        )
+
+    async def query(self, question: str):
+        # Use async method if available
+        if hasattr(self.retriever, "aretrieve"):
+            docs = await self.retriever.aretrieve(question)
+        else:
+            # fallback: call the private method with run_manager=None
+            docs = await asyncio.to_thread(self.retriever._get_relevant_documents, question, run_manager=None)
+
+        context = "\n\n".join([d.page_content for d in docs])
+        prompt_text = self.prompt.format(context=context, question=question)
+        response = await asyncio.to_thread(self.llm.invoke, prompt_text)
+        return response
+
+
+# -------------------- Example Runner --------------------
+async def main():
+    url = input("Enter URL: ").strip()
+    query = input("Enter your question: ").strip()
+
+    bot = Chatbot([url])
+    await bot.initialize()
+    answer = await bot.query(query)
+    rprint(f"\n[bold cyan]=== Answer ===[/bold cyan]\n{answer}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
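
A note on the cache layout produced by `prepare_document` above: the folder name is derived from the URL by dropping the scheme prefix and replacing dots with dashes, which is where the `cache/playwright-dev/pages/...` path in the README example comes from. A small illustration (not part of the commit):

```python
# Illustration of the folder naming used by prepare_document (illustrative URL).
url = "https://playwright.dev"
folder = url[8:].replace('.', '-').split('/')[0]
print(folder)  # "playwright-dev"
# Single URL   -> cache/playwright-dev/pages/page_<n>.txt
# List of URLs -> cache/list_playwright-dev/pages/page_<n>.txt (keyed by the first URL)
```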
main.py
ADDED
@@ -0,0 +1,190 @@
+# main.py
+from dotenv import load_dotenv
+load_dotenv()
+
+import sys
+import uuid
+import asyncio
+import logging
+from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, BackgroundTasks
+from fastapi.middleware.cors import CORSMiddleware
+from chatbot import Chatbot
+
+# -----------------------------
+# Windows Asyncio Fix
+# -----------------------------
+if sys.platform == "win32":
+    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+
+# -----------------------------
+# FastAPI app & CORS
+# -----------------------------
+app = FastAPI(
+    title="Session-Based RAG Chatbot API",
+    description="Session-based RAG Chatbot API with WebSocket support",
+    version="1.1.0"
+)
+
+origins = [
+    "http://localhost:8080",
+    "http://127.0.0.1:8080",
+    "http://127.0.0.1:5500",
+]
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# -----------------------------
+# Session storage
+# -----------------------------
+chatbot_sessions = {}  # {session_id: Chatbot instance, None while initializing, "err" if failed}
+
+# -----------------------------
+# Root endpoint
+# -----------------------------
+@app.get("/")
+def read_root():
+    return {"message": "Welcome to the Session-Based RAG Chatbot API!", "status": "Ready"}
+
+@app.get("/create_session")
+def create_session():
+    return {"session": str(uuid.uuid4())}
+
+
+@app.get("/session_status/{session_id}")
+def session_status(session_id: str):
+    """
+    Returns the current status of a chatbot session.
+    Status can be:
+    - initializing (session exists but chatbot not ready)
+    - ready (chatbot instance ready)
+    - failed (chatbot initialization failed)
+    """
+    if session_id not in chatbot_sessions:
+        return {"status": "not_found"}
+
+    chatbot = chatbot_sessions[session_id]
+    if chatbot is None:
+        return {"status": "initializing"}
+    elif chatbot == "err":
+        return {"status": "failed"}
+
+    return {"status": "ready"}
+
+
+# -----------------------------
+# Helper: Run async init in background
+# -----------------------------
+def run_chatbot_init(session_id, urls, llm_model, embedding_model, api_key):
+    asyncio.create_task(initialize_chatbot(session_id, urls, llm_model, embedding_model, api_key))
+
+# -----------------------------
+# Scrape & initialize chatbot
+# -----------------------------
+@app.post("/scrape/")
+async def scrape_and_load(response: dict, background_tasks: BackgroundTasks):
+    session_id = response.get("session_id")
+    urls = response.get("urls")
+    llm_model = response.get("llm_model", "TheBloke/Llama-2-7B-Chat-GGML")
+    embedding_model = response.get("embedding_model", "BAAI/bge-small-en")
+    api_key = response.get("api_key", None)
+
+    if not urls:
+        raise HTTPException(status_code=400, detail="urls are required.")
+
+    if session_id in chatbot_sessions:
+        return {"message": f"Chatbot for session {session_id} already initialized.", "session_id": session_id}
+
+    # Mark session as initializing
+    chatbot_sessions[session_id] = None
+
+    # Run the async initialization as a background task
+    async def init_wrapper():
+        try:
+            await initialize_chatbot(session_id, urls, llm_model, embedding_model, api_key)
+        except Exception as e:
+            logging.error(f"[{session_id}] Initialization error: {e}", exc_info=True)
+            chatbot_sessions[session_id] = "err"
+
+    background_tasks.add_task(init_wrapper)
+
+    logging.info(f"[{session_id}] Chatbot initialization scheduled in background.")
+    return {"message": "Chatbot initialization started.", "session_id": session_id}
+
+# -----------------------------
+# Initialize chatbot
+# -----------------------------
+async def initialize_chatbot(session_id, urls, llm_model, embedding_model, api_key):
+    try:
+        logging.info(f"[{session_id}] Initializing chatbot...")
+        chatbot = Chatbot(
+            url=urls,
+            llm_model=llm_model,
+            embedding_model=embedding_model,
+            api_key=api_key
+        )
+        await chatbot.initialize()
+
+        chatbot_sessions[session_id] = chatbot
+        logging.info(f"[{session_id}] Chatbot ready.")
+    except NotImplementedError as e:
+        logging.error(f"[{session_id}] Playwright async not supported on Windows: {e}", exc_info=True)
+        chatbot_sessions[session_id] = "err"
+    except Exception as e:
+        logging.error(f"[{session_id}] Initialization failed: {e}", exc_info=True)
+        chatbot_sessions[session_id] = "err"
+
+# -----------------------------
+# WebSocket endpoint
+# -----------------------------
+@app.websocket("/ws/chat/{session_id}")
+async def websocket_endpoint(websocket: WebSocket, session_id: str):
+    await websocket.accept()
+    logging.info(f"[{session_id}] WebSocket connected.")
+
+    try:
+        # Wait until chatbot is ready (or initialization has failed)
+        while session_id not in chatbot_sessions or chatbot_sessions[session_id] is None:
+            await websocket.send_json({"text": "Initializing chatbot, please wait..."})
+            await asyncio.sleep(1)
+
+        chatbot_instance = chatbot_sessions[session_id]
+        if chatbot_instance == "err":
+            await websocket.send_json({
+                "text": "Chatbot initialization failed. Likely due to a Playwright async issue on Windows."
+            })
+            return
+
+        await websocket.send_json({"text": f"Chatbot session {session_id} is ready! You can start chatting."})
+
+        while True:
+            data = await websocket.receive_json()
+            query = data.get("query")
+            if not query:
+                continue
+
+            response_text = await chatbot_instance.query(query)
+            await websocket.send_json({"text": response_text})
+
+    except WebSocketDisconnect:
+        logging.info(f"[{session_id}] WebSocket disconnected.")
+    except Exception as e:
+        logging.error(f"[{session_id}] WebSocket error: {e}", exc_info=True)
+        try:
+            await websocket.send_json({"text": "An unexpected server error occurred."})
+        except Exception:
+            pass
+
+# -----------------------------
+# Run with: uvicorn main:app --reload
+# -----------------------------
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run("main:app", host="127.0.0.1", port=8000, reload=True)
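
A typical client of this API creates a session, posts the URLs to `/scrape/`, polls `/session_status/{session_id}`, and then chats over the WebSocket. Below is a minimal client sketch, not part of the commit, using `requests` and `websockets` (both pinned in requirements.txt); the base URL assumes a local `uvicorn main:app` run on port 8000, and the target URL is illustrative.

```python
import asyncio
import json
import requests
import websockets

BASE = "http://127.0.0.1:8000"  # assumed local deployment

async def chat() -> None:
    session_id = requests.get(f"{BASE}/create_session", timeout=10).json()["session"]

    # Kick off scraping + chatbot initialization in the background.
    requests.post(f"{BASE}/scrape/", json={
        "session_id": session_id,
        "urls": ["https://playwright.dev"],  # illustrative URL
    }, timeout=30)

    # Poll until the chatbot is ready (or initialization failed).
    while True:
        status = requests.get(f"{BASE}/session_status/{session_id}", timeout=10).json()["status"]
        if status in ("ready", "failed"):
            break
        await asyncio.sleep(2)

    if status == "failed":
        raise RuntimeError("Chatbot initialization failed")

    async with websockets.connect(f"ws://127.0.0.1:8000/ws/chat/{session_id}") as ws:
        print(await ws.recv())  # readiness banner from the server
        await ws.send(json.dumps({"query": "Describe Playwright and its benefits"}))
        print(json.loads(await ws.recv())["text"])

if __name__ == "__main__":
    asyncio.run(chat())
```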
requirements.txt
ADDED
@@ -0,0 +1,158 @@
+accelerate==1.11.0
+aiofiles==25.1.0
+aiohappyeyeballs==2.6.1
+aiohttp==3.13.1
+aiosignal==1.4.0
+aiosqlite==0.21.0
+alphashape==1.3.1
+altair==5.5.0
+annotated-types==0.7.0
+anyio==4.11.0
+attrs==25.4.0
+beautifulsoup4==4.14.2
+bitsandbytes==0.42.0
+blinker==1.9.0
+Brotli==1.1.0
+cachetools==6.2.1
+certifi==2025.10.5
+cffi==2.0.0
+chardet==5.2.0
+charset-normalizer==3.4.4
+click==8.3.0
+click-log==0.4.0
+Crawl4AI==0.7.6
+cryptography==46.0.3
+cssselect==1.3.0
+ctransformers==0.2.27
+dataclasses-json==0.6.7
+distro==1.9.0
+dotenv==0.9.9
+faiss-cpu==1.12.0
+fake-http-header==0.3.5
+fake-useragent==2.2.0
+fastapi==0.119.1
+fastuuid==0.14.0
+filelock==3.20.0
+frozenlist==1.8.0
+fsspec==2025.9.0
+gitdb==4.0.12
+GitPython==3.1.45
+greenlet==3.2.4
+h11==0.16.0
+h2==4.3.0
+hf-xet==1.1.10
+hpack==4.1.0
+httpcore==1.0.9
+httptools==0.7.1
+httpx==0.28.1
+httpx-sse==0.4.3
+huggingface-hub==0.35.3
+humanize==4.14.0
+hyperframe==6.1.0
+idna==3.11
+importlib_metadata==8.7.0
+Jinja2==3.1.6
+jiter==0.11.1
+joblib==1.5.2
+jsonpatch==1.33
+jsonpointer==3.0.0
+jsonschema==4.25.1
+jsonschema-specifications==2025.9.1
+langchain==1.0.2
+langchain-classic==1.0.0
+langchain-community==0.4
+langchain-core==1.0.0
+langchain-huggingface==1.0.0
+langchain-openai==1.0.1
+langchain-text-splitters==1.0.0
+langgraph==1.0.1
+langgraph-checkpoint==3.0.0
+langgraph-prebuilt==1.0.1
+langgraph-sdk==0.2.9
+langsmith==0.4.37
+lark==1.3.0
+litellm==1.78.6
+lxml==5.4.0
+markdown-it-py==4.0.0
+MarkupSafe==3.0.3
+marshmallow==3.26.1
+mdurl==0.1.2
+mpmath==1.3.0
+multidict==6.7.0
+mypy_extensions==1.1.0
+narwhals==2.9.0
+networkx==3.5
+nltk==3.9.2
+numpy==2.3.4
+openai==2.6.0
+orjson==3.11.3
+ormsgpack==1.11.0
+packaging==25.0
+pandas==2.3.3
+patchright==1.55.2
+pillow==11.3.0
+playwright==1.55.0
+propcache==0.4.1
+protobuf==6.33.0
+psutil==7.1.1
+py-cpuinfo==9.0.0
+pyarrow==21.0.0
+pycparser==2.23
+pydantic==2.12.3
+pydantic-settings==2.11.0
+pydantic_core==2.41.4
+pydeck==0.9.1
+pyee==13.0.0
+Pygments==2.19.2
+pyOpenSSL==25.3.0
+python-dateutil==2.9.0.post0
+python-dotenv==1.1.1
+pytz==2025.2
+PyYAML==6.0.3
+rank-bm25==0.2.2
+referencing==0.37.0
+regex==2025.10.23
+requests==2.32.5
+requests-toolbelt==1.0.0
+rich==14.2.0
+rpds-py==0.27.1
+rtree==1.4.1
+safetensors==0.6.2
+scikit-learn==1.7.2
+scipy==1.16.2
+sentence-transformers==5.1.2
+setuptools==80.9.0
+shapely==2.1.2
+six==1.17.0
+smmap==5.0.2
+sniffio==1.3.1
+snowballstemmer==2.2.0
+soupsieve==2.8
+SQLAlchemy==2.0.44
+starlette==0.48.0
+streamlit==1.50.0
+sympy==1.14.0
+tenacity==9.1.2
+tf-playwright-stealth==1.2.0
+threadpoolctl==3.6.0
+tiktoken==0.12.0
+tokenizers==0.22.1
+toml==0.10.2
+torch==2.9.0
+tornado==6.5.2
+tqdm==4.67.1
+transformers==4.57.1
+trimesh==4.8.3
+typing-inspect==0.9.0
+typing-inspection==0.4.2
+typing_extensions==4.15.0
+tzdata==2025.2
+urllib3==2.5.0
+uvicorn==0.38.0
+uvloop==0.22.1
+watchfiles==1.1.1
+websockets==15.0.1
+xxhash==3.6.0
+yarl==1.22.0
+zipp==3.23.0
+zstandard==0.25.0
worker.py
ADDED
@@ -0,0 +1,154 @@
+# worker.py
+import os
+import asyncio
+import psutil
+from urllib.parse import urlparse, urlunparse
+from typing import List
+from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode
+from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
+from crawl4ai.content_filter_strategy import PruningContentFilter
+import traceback
+
+# ------------------------------
+# File paths & config
+# ------------------------------
+__location__ = os.path.dirname(os.path.abspath(__file__))
+batch = 32  # max concurrent crawls
+goto_timeout = 60_000  # 1 minute
+
+# ------------------------------
+# Utility functions
+# ------------------------------
+def normalize_url(url: str) -> str:
+    """Normalize URL to avoid duplicates."""
+    parsed = urlparse(url)
+    return urlunparse((parsed.scheme, parsed.netloc, parsed.path, '', '', ''))
+
+async def get_internal_urls(url_set: set, visited: set, crawler) -> set:
+    """Collect internal links from a page."""
+    internal_urls = crawler.links.get("internal", [])
+    for link in internal_urls:
+        href = link.get("href")
+        if href and href.startswith("http"):
+            normalized_href = normalize_url(href)
+            if normalized_href not in visited:
+                url_set.add(normalized_href)
+    return url_set
+
+# ------------------------------
+# Core crawling function
+# ------------------------------
+async def crawl_parallel(urls: List[str] | str, file_path: str, max_concurrent: int = batch):
+    """Crawl multiple URLs asynchronously with retries, save pages, and track failures."""
+    text_pages = set()
+    not_visited = set(urls if isinstance(urls, list) else [urls])
+    visited = set()
+    retry = set()
+    failed = set()
+    was_str = isinstance(urls, str)
+    n = 1
+
+    os.makedirs(file_path, exist_ok=True)
+
+    process = psutil.Process()
+    peak_memory = 0
+    def log_memory(prefix: str = ""):
+        nonlocal peak_memory
+        current_mem = process.memory_info().rss
+        peak_memory = max(peak_memory, current_mem)
+        print(f"{prefix} Memory: {current_mem // (1024*1024)} MB | Peak: {peak_memory // (1024*1024)} MB")
+
+    # Browser & crawler config
+    browser_config = BrowserConfig(
+        headless=True,
+        verbose=False,
+        extra_args=["--disable-gpu", "--disable-dev-shm-usage", "--no-sandbox"],
+        text_mode=True
+    )
+    crawl_config = CrawlerRunConfig(
+        cache_mode=CacheMode.BYPASS,
+        markdown_generator=DefaultMarkdownGenerator(
+            content_filter=PruningContentFilter(threshold=0.6),
+            options={"ignore_links": True}
+        ),
+        page_timeout=goto_timeout
+    )
+
+    crawler = AsyncWebCrawler(config=browser_config)
+    await crawler.start()
+    print("\n=== Starting robust parallel crawling ===")
+
+    async def safe_crawl(url, session_id):
+        """Crawl a URL safely, return result or None."""
+        try:
+            result = await crawler.arun(url=url, config=crawl_config, session_id=session_id)
+            return result
+        except Exception as e:
+            print(f"[WARN] Failed to crawl {url}: {e}")
+            return None
+
+    try:
+        while not_visited:
+            urls_batch = list(not_visited)[:max_concurrent]
+            tasks = [safe_crawl(url, f"session_{i}") for i, url in enumerate(urls_batch)]
+
+            log_memory(prefix=f"Before batch {n}: ")
+            results = await asyncio.gather(*tasks, return_exceptions=True)
+            log_memory(prefix=f"After batch {n}: ")
+
+            for url, result in zip(urls_batch, results):
+                if isinstance(result, Exception) or result is None or not getattr(result, "success", False):
+                    if url not in retry:
+                        retry.add(url)
+                        print(f"[INFO] Retry scheduled for {url}")
+                    else:
+                        failed.add(url)
+                        not_visited.discard(url)
+                        visited.add(url)
+                        print(f"[ERROR] Crawling failed for {url} after retry")
+                else:
+                    text_pages.add(result.markdown.fit_markdown)
+                    if was_str:
+                        internal_urls = result.links.get("internal", [])
+                        for link in internal_urls:
+                            href = link.get("href")
+                            if href and href.startswith("http"):
+                                normalized_href = normalize_url(href)
+                                if normalized_href not in visited:
+                                    not_visited.add(normalized_href)
+                    visited.add(url)
+                    retry.discard(url)
+                    not_visited.discard(url)
+            n += 1
+
+    except Exception as e:
+        traceback.print_exc()
+        print(e)
+    finally:
+        await crawler.close()
+        log_memory(prefix="Final: ")
+
+    # Save pages
+    pages = [p for p in text_pages if p.strip()]
+    for i, page in enumerate(pages):
+        with open(os.path.join(file_path, f"page_{i+1}.txt"), "w", encoding="utf-8") as f:
+            f.write(page)
+
+    print(f"\nSummary:")
+    print(f"  - Successfully crawled pages: {len(pages)}")
+    print(f"  - Failed URLs: {len(failed)} -> {failed}")
+    print(f"Peak memory usage: {peak_memory // (1024*1024)} MB")
+
+    return {
+        "success_count": len(pages),
+        "failed_urls": list(failed),
+        "peak_memory_MB": peak_memory // (1024*1024)
+    }
+
+# ------------------------------
+# Public scrape function
+# ------------------------------
+async def scrape_website(urls: str | list, file_path: str):
+    """Wrapper to start crawling and return summary."""
+    summary = await crawl_parallel(urls, file_path, max_concurrent=batch)
+    return summary
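
For completeness, `scrape_website` can also be exercised on its own, in the same way `prepare_document` in chatbot.py calls it. A minimal sketch (the URL and output directory are illustrative, not part of the commit):

```python
import asyncio
from worker import scrape_website

async def run() -> None:
    # A single string URL makes crawl_parallel follow the site's internal links;
    # passing a list of URLs crawls only those pages.
    summary = await scrape_website("https://playwright.dev", "cache/playwright-dev/pages")
    print(summary)  # {"success_count": ..., "failed_urls": [...], "peak_memory_MB": ...}

if __name__ == "__main__":
    asyncio.run(run())
```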