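# End-to-end smoke tests for a Crawl4AI Docker deployment. The script
# exercises the REST API: async submission via POST /crawl with polling on
# GET /task/{task_id}, blocking submission via /crawl_sync, and (optionally)
# JS execution, CSS selection, structured/LLM extraction, and screenshots.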
import requests
import json
import time
import sys
import base64
import os
from typing import Any, Dict, Optional
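
# Thin client for the Crawl4AI HTTP API. If no token is passed explicitly, it
# falls back to the CRAWL4AI_API_TOKEN environment variable; when neither is
# set, requests are sent without an Authorization header (for unsecured
# deployments).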
class Crawl4AiTester:
    def __init__(self, base_url: str = "http://localhost:11235", api_token: Optional[str] = None):
        self.base_url = base_url
        self.api_token = api_token or os.getenv('CRAWL4AI_API_TOKEN')  # Check environment variable as fallback
        self.headers = {'Authorization': f'Bearer {self.api_token}'} if self.api_token else {}
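
    # Submit a crawl job to /crawl, then poll /task/{task_id} every 2 seconds
    # until the task completes, fails, or the client-side timeout expires.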
    def submit_and_wait(self, request_data: Dict[str, Any], timeout: int = 300) -> Dict[str, Any]:
        # Submit crawl job
        response = requests.post(f"{self.base_url}/crawl", json=request_data, headers=self.headers)
        if response.status_code == 403:
            raise Exception("API token is invalid or missing")
        task_id = response.json()["task_id"]
        print(f"Task ID: {task_id}")

        # Poll for result
        start_time = time.time()
        while True:
            if time.time() - start_time > timeout:
                raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds")
            result = requests.get(f"{self.base_url}/task/{task_id}", headers=self.headers)
            status = result.json()
            if status["status"] == "failed":
                print("Task failed:", status.get("error"))
                raise Exception(f"Task failed: {status.get('error')}")
            if status["status"] == "completed":
                return status
            time.sleep(2)
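
    # Blocking variant: /crawl_sync keeps the HTTP connection open until the
    # crawl finishes, so no polling is needed; the server replies 408 if the
    # task exceeds its server-side timeout.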
    def submit_sync(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
        response = requests.post(f"{self.base_url}/crawl_sync", json=request_data, headers=self.headers, timeout=60)
        if response.status_code == 408:
            raise TimeoutError("Task did not complete within server timeout")
        response.raise_for_status()
        return response.json()
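
# Runs the health check and the enabled test cases against the deployment.
# The base_url below is hardcoded to a remote deployment; swap in the
# commented-out localhost URL to test a local container.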
def test_docker_deployment(version="basic"):
    tester = Crawl4AiTester(
        # base_url="http://localhost:11235",
        base_url="https://crawl4ai-sby74.ondigitalocean.app",
        api_token="test"
    )
    print(f"Testing Crawl4AI Docker {version} version")

    # Health check with timeout and retry
    max_retries = 5
    for i in range(max_retries):
        try:
            health = requests.get(f"{tester.base_url}/health", timeout=10)
            print("Health check:", health.json())
            break
        except requests.exceptions.RequestException:
            if i == max_retries - 1:
                print(f"Failed to connect after {max_retries} attempts")
                sys.exit(1)
            print(f"Waiting for service to start (attempt {i+1}/{max_retries})...")
            time.sleep(5)
    # Test cases based on version
    test_basic_crawl(tester)
    test_basic_crawl_sync(tester)
    # if version in ["full", "transformer"]:
    #     test_cosine_extraction(tester)
    # test_js_execution(tester)
    # test_css_selector(tester)
    # test_structured_extraction(tester)
    # test_llm_extraction(tester)
    # test_llm_with_ollama(tester)
    # test_screenshot(tester)
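
# Minimal end-to-end check: crawl a single page asynchronously and assert
# that non-empty markdown came back.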
def test_basic_crawl(tester: Crawl4AiTester):
    print("\n=== Testing Basic Crawl ===")
    request = {
        "urls": "https://www.nbcnews.com/business",
        "priority": 10,
        "session_id": "test"
    }
    result = tester.submit_and_wait(request)
    print(f"Basic crawl result length: {len(result['result']['markdown'])}")
    assert result["result"]["success"]
    assert len(result["result"]["markdown"]) > 0
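
# Same crawl as above, but through the blocking /crawl_sync endpoint.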
def test_basic_crawl_sync(tester: Crawl4AiTester):
    print("\n=== Testing Basic Crawl (Sync) ===")
    request = {
        "urls": "https://www.nbcnews.com/business",
        "priority": 10,
        "session_id": "test"
    }
    result = tester.submit_sync(request)
    print(f"Basic crawl result length: {len(result['result']['markdown'])}")
    assert result['status'] == 'completed'
    assert result['result']['success']
    assert len(result['result']['markdown']) > 0
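
# Executes a JS snippet in the page (clicking a "Load More" button if one
# exists) and waits for a CSS condition before extraction.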
def test_js_execution(tester: Crawl4AiTester):
    print("\n=== Testing JS Execution ===")
    request = {
        "urls": "https://www.nbcnews.com/business",
        "priority": 8,
        "js_code": [
            "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
        ],
        "wait_for": "article.tease-card:nth-child(10)",
        "crawler_params": {
            "headless": True
        }
    }
    result = tester.submit_and_wait(request)
    print(f"JS execution result length: {len(result['result']['markdown'])}")
    assert result["result"]["success"]
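
# Restricts extraction to elements matching a CSS selector.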
def test_css_selector(tester: Crawl4AiTester):
    print("\n=== Testing CSS Selector ===")
    request = {
        "urls": "https://www.nbcnews.com/business",
        "priority": 7,
        "css_selector": ".wide-tease-item__description",
        "crawler_params": {
            "headless": True
        },
        "extra": {"word_count_threshold": 10}
    }
    result = tester.submit_and_wait(request)
    print(f"CSS selector result length: {len(result['result']['markdown'])}")
    assert result["result"]["success"]
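
# Schema-driven extraction with the json_css strategy: baseSelector matches
# the repeating rows, and each field is read from a child selector.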
def test_structured_extraction(tester: Crawl4AiTester):
    print("\n=== Testing Structured Extraction ===")
    schema = {
        "name": "Coinbase Crypto Prices",
        "baseSelector": ".cds-tableRow-t45thuk",
        "fields": [
            {
                "name": "crypto",
                "selector": "td:nth-child(1) h2",
                "type": "text",
            },
            {
                "name": "symbol",
                "selector": "td:nth-child(1) p",
                "type": "text",
            },
            {
                "name": "price",
                "selector": "td:nth-child(2)",
                "type": "text",
            }
        ],
    }
    request = {
        "urls": "https://www.coinbase.com/explore",
        "priority": 9,
        "extraction_config": {
            "type": "json_css",
            "params": {
                "schema": schema
            }
        }
    }
    result = tester.submit_and_wait(request)
    extracted = json.loads(result["result"]["extracted_content"])
    print(f"Extracted {len(extracted)} items")
    print("Sample item:", json.dumps(extracted[0], indent=2))
    assert result["result"]["success"]
    assert len(extracted) > 0
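
# LLM-based extraction against a JSON schema. Needs OPENAI_API_KEY in the
# environment; the try/except keeps a missing key from aborting the run.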
def test_llm_extraction(tester: Crawl4AiTester):
    print("\n=== Testing LLM Extraction ===")
    schema = {
        "type": "object",
        "properties": {
            "model_name": {
                "type": "string",
                "description": "Name of the OpenAI model."
            },
            "input_fee": {
                "type": "string",
                "description": "Fee for input tokens for the OpenAI model."
            },
            "output_fee": {
                "type": "string",
                "description": "Fee for output tokens for the OpenAI model."
            }
        },
        "required": ["model_name", "input_fee", "output_fee"]
    }
    request = {
        "urls": "https://openai.com/api/pricing",
        "priority": 8,
        "extraction_config": {
            "type": "llm",
            "params": {
                "provider": "openai/gpt-4o-mini",
                "api_token": os.getenv("OPENAI_API_KEY"),
                "schema": schema,
                "extraction_type": "schema",
                "instruction": "From the crawled content, extract all mentioned model names along with their fees for input and output tokens."
            }
        },
        "crawler_params": {"word_count_threshold": 1}
    }
    try:
        result = tester.submit_and_wait(request)
        extracted = json.loads(result["result"]["extracted_content"])
        print(f"Extracted {len(extracted)} model pricing entries")
        print("Sample entry:", json.dumps(extracted[0], indent=2))
        assert result["result"]["success"]
    except Exception as e:
        print(f"LLM extraction test failed (might be due to missing API key): {str(e)}")
def test_llm_with_ollama(tester: Crawl4AiTester):
    print("\n=== Testing LLM with Ollama ===")
    schema = {
        "type": "object",
        "properties": {
            "article_title": {
                "type": "string",
                "description": "The main title of the news article"
            },
            "summary": {
                "type": "string",
                "description": "A brief summary of the article content"
            },
            "main_topics": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Main topics or themes discussed in the article"
            }
        }
    }
    request = {
        "urls": "https://www.nbcnews.com/business",
        "priority": 8,
        "extraction_config": {
            "type": "llm",
            "params": {
                "provider": "ollama/llama2",
                "schema": schema,
                "extraction_type": "schema",
                "instruction": "Extract the main article information including title, summary, and main topics."
            }
        },
        "extra": {"word_count_threshold": 1},
        "crawler_params": {"verbose": True}
    }
    try:
        result = tester.submit_and_wait(request)
        extracted = json.loads(result["result"]["extracted_content"])
        print("Extracted content:", json.dumps(extracted, indent=2))
        assert result["result"]["success"]
    except Exception as e:
        print(f"Ollama extraction test failed: {str(e)}")
def test_cosine_extraction(tester: Crawl4AiTester):
    print("\n=== Testing Cosine Extraction ===")
    request = {
        "urls": "https://www.nbcnews.com/business",
        "priority": 8,
        "extraction_config": {
            "type": "cosine",
            "params": {
                "semantic_filter": "business finance economy",
                "word_count_threshold": 10,
                "max_dist": 0.2,
                "top_k": 3
            }
        }
    }
    try:
        result = tester.submit_and_wait(request)
        extracted = json.loads(result["result"]["extracted_content"])
        print(f"Extracted {len(extracted)} text clusters")
        print("First cluster tags:", extracted[0]["tags"])
        assert result["result"]["success"]
    except Exception as e:
        print(f"Cosine extraction test failed: {str(e)}")
def test_screenshot(tester: Crawl4AiTester):
    print("\n=== Testing Screenshot ===")
    request = {
        "urls": "https://www.nbcnews.com/business",
        "priority": 5,
        "screenshot": True,
        "crawler_params": {
            "headless": True
        }
    }
    result = tester.submit_and_wait(request)
    print("Screenshot captured:", bool(result["result"]["screenshot"]))
    if result["result"]["screenshot"]:
        # Save the decoded screenshot to disk
        screenshot_data = base64.b64decode(result["result"]["screenshot"])
        with open("test_screenshot.jpg", "wb") as f:
            f.write(screenshot_data)
        print("Screenshot saved as test_screenshot.jpg")
    assert result["result"]["success"]
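
# Usage: python <this_script> [basic|full|transformer]; the version defaults
# to "basic".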
if __name__ == "__main__":
    version = sys.argv[1] if len(sys.argv) > 1 else "basic"
    # version = "full"
    test_docker_deployment(version)