sa / scrape.py
scribbyotx's picture
Upload 5 files
6d7c565 verified
raw
history blame
No virus
2.99 kB
import io
import os
import re
import pandas as pd
from typing import Any, Dict, List
import requests
from PIL import Image as PILImage
from scraper import ScraperBot, ScraperBotConfig
from helpers import starts_with_quotes, get_start_end_quotes
from dataclasses import dataclass
from datasets import Image
@dataclass(frozen=True)
class HFDatasetScheme:
caption: str
image: Image(decode=True)
link: str
message_id: str
timestamp: str
url_pattern = re.compile(r'https?://\S+')
def parse_fn(message: Dict[str, Any]) -> List[HFDatasetScheme]:
"""Parses a message into a list of Hugging Face Dataset Schemes.
Parameters
----------
message : Dict[str, Any]
The message to parse.
Returns
-------
List[HFDatasetScheme]
A list of Hugging Face Dataset Schemes.
"""
content = message["content"]
(first_quote_index, last_quote_index) = get_start_end_quotes(content)
# Extract the text between the first and last quotes to get the complete prompt
prompt = content[first_quote_index + 1:last_quote_index].strip()
image_urls = url_pattern.findall(content)
timestamp = message["timestamp"]
message_id = message["id"]
return [HFDatasetScheme(caption=prompt, image=None, link=image_url, message_id=message_id, timestamp=timestamp)
for image_url in image_urls]
def condition_fn(message: Dict[str, Any]) -> bool:
"""Checks if a message meets the condition to be parsed.
Parameters
----------
message : Dict[str, Any]
The message to check.
Returns
-------
bool
True if the message meets the condition, False otherwise.
"""
return url_pattern.search(message["content"]) and starts_with_quotes(message["content"])
def prepare_dataset(messages: List[HFDatasetScheme]) -> pd.DataFrame:
return pd.DataFrame(
{
"caption": [msg.caption for msg in messages],
"image": [
None for msg in messages
], # Initialize to None, will be filled in later
"link": [
msg.link for msg in messages
], # will maintain just because we use it to filter
"message_id": [msg.message_id for msg in messages],
"timestamp": [msg.timestamp for msg in messages],
}
)
def get_image(link: str) -> bytes:
image = PILImage.open(requests.get(link, stream=True).raw).convert("RGB")
img_byte_arr = io.BytesIO()
image.save(img_byte_arr, format="PNG")
return {"bytes": img_byte_arr.getvalue(), "path": None}
if __name__ == "__main__":
config_path = os.path.join(os.path.dirname(__file__), "config.json")
config = ScraperBotConfig.from_json(config_path)
bot = ScraperBot(config=config, HFDatasetScheme=HFDatasetScheme, prepare_dataset=prepare_dataset, parse_fn=parse_fn, condition_fn=condition_fn, download_fn=get_image)
bot.scrape(fetch_all=os.environ.get("FETCH_ALL", "false").lower() == "true")