import sqlite3
from functools import partial
from itertools import islice
from pathlib import Path
from typing import Optional

import click

# Number of caption files inserted per transaction.
BATCH_SIZE = 1024

# Caption substrings that mark an entry as not-safe-for-work. They are used
# inside a SQL LIKE pattern (wrapped in ``%...%``), so ``%`` acts as a
# wildcard: ``subreddit:%nsfw`` matches ``subreddit:`` followed anywhere later
# in the caption by ``nsfw``.
NSFW_MARKERS = (
    "rating:unsafe",
    "rating:explicit",
    "rating:mature",
    "meta:nsfw",
    "subreddit:%nsfw",
)


@click.command()
@click.argument(
    "input_dir", type=click.Path(exists=True, file_okay=False, path_type=Path)
)
@click.argument(
    "output", type=click.Path(dir_okay=False, writable=True, path_type=Path)
)
@click.option("--image-host", help="base URL of images")
@click.option(
    "--explicit/--no-explicit",
    default=False,
    help="Keep captions tagged as explicit/NSFW (they are removed by default).",
)
def main(input_dir: Path, output: Path, image_host: Optional[str], explicit: bool):
    """Build a SQLite caption database OUTPUT from the *.txt files in INPUT_DIR."""
    connection = sqlite3.connect(output)
    try:
        _main_with_connection(input_dir, connection, image_host, explicit)
    finally:
        connection.close()


def _main_with_connection(
    input_dir: Path,
    connection: sqlite3.Connection,
    image_host: Optional[str] = None,
    explicit: bool = False,
):
    """Populate ``connection`` with the captions found in ``input_dir``.

    :param input_dir: Directory whose ``*.txt`` files are loaded; each file's
        stem becomes the ``image_key`` and its contents the ``caption``.
    :param connection: Open SQLite connection to write into.
    :param image_host: If given, an ``images`` view is created that prefixes
        each ``image_key`` with this URL and appends ``.jpg``.
    :param explicit: When False, captions matching ``NSFW_MARKERS`` are deleted.
    :raises sqlite3.IntegrityError: if a file's stem is already present in
        ``captions`` (``image_key`` is the primary key and rows are inserted
        with a plain INSERT).
    """
    _create_schema(connection, image_host)
    _load_captions(connection, input_dir)
    if not explicit:
        _remove_nsfw(connection)
    _build_fts_index(connection)


def _create_schema(connection: sqlite3.Connection, image_host: Optional[str]):
    """Create the ``captions`` table and, if ``image_host`` is set, the ``images`` view."""
    connection.execute(
        "CREATE TABLE IF NOT EXISTS "
        " captions(image_key text PRIMARY KEY, caption text NOT NULL);"
    )
    if image_host:
        # The host is inlined into the view definition, so it must be quoted
        # as an SQL literal rather than bound as a parameter.
        connection.execute(f"""
            CREATE VIEW IF NOT EXISTS images AS
            SELECT {sql_quote(connection, image_host)} || image_key || '.jpg' AS image,
                   caption, rowid
            FROM captions
        """)


def _load_captions(connection: sqlite3.Connection, input_dir: Path):
    """Insert every ``*.txt`` file in ``input_dir``, committing once per batch."""
    text_files = input_dir.glob("*.txt")
    with click.progressbar(chunked(text_files, BATCH_SIZE)) as progress:
        for batch in progress:
            # Read explicitly as UTF-8 so the result does not depend on the
            # platform's default locale encoding.
            pairs = (
                (text_file.stem, text_file.read_text(encoding="utf-8"))
                for text_file in batch
            )
            with connection:
                connection.executemany(
                    "INSERT INTO captions(image_key, caption) VALUES(?, ?)",
                    pairs,
                )


def _remove_nsfw(connection: sqlite3.Connection):
    """Delete captions containing any of ``NSFW_MARKERS`` and report the counts."""
    for marker in NSFW_MARKERS:
        with connection:
            cursor = connection.execute(
                "DELETE FROM captions WHERE caption LIKE ?", (f"%{marker}%",)
            )
        click.echo(f"Removed {cursor.rowcount} {marker} rows")


def _build_fts_index(connection: sqlite3.Connection):
    """Create (if needed) and fully rebuild the FTS5 index over ``captions``."""
    with connection:
        # IF NOT EXISTS matches the rest of the schema so the script can be
        # re-run against an existing database.
        connection.execute("""
            CREATE VIRTUAL TABLE IF NOT EXISTS captions_fts
            USING fts5(caption, image_key UNINDEXED, content=captions)
        """)
        # captions_fts is an external-content table with no sync triggers, so
        # it is not updated automatically when captions changes. The 'rebuild'
        # command discards the index and re-reads every row of the content
        # table, dropping any entries for rows deleted since the last build.
        connection.execute(
            "INSERT INTO captions_fts(captions_fts) VALUES('rebuild')"
        )


def chunked(iterable, n):
    """Yield successive lists of up to ``n`` items from ``iterable``.

    Iteration stops at the first empty list, i.e. once the input is exhausted.
    """
    return iter(partial(take, n, iter(iterable)), [])


def take(n, iterable):
    """Return the next ``n`` items of ``iterable`` (fewer if it runs out) as a list."""
    return list(islice(iterable, n))


def sql_quote(connection, value: str) -> str:
    """
    Apply SQLite string quoting to a value, including wrapping it in single quotes.

    :param connection: Open SQLite connection used to perform the quoting
    :param value: String to quote
    """
    # Normally we would use .execute(sql, [params]) for escaping, but
    # occasionally that isn't available - most notably when we need
    # to include a "... DEFAULT 'value'" in a column definition.
    return connection.execute(
        # Use SQLite itself to correctly escape this string:
        "SELECT quote(:value)",
        {"value": value},
    ).fetchone()[0]


if __name__ == "__main__":
    main()