# Repository-viewer residue (not code): "Update src/textdir2sql/loading.py",
# commit 60fc52e, shown next to ZeroCool94's picture.
import sqlite3
from functools import partial
from itertools import islice
from pathlib import Path
import click
# Number of caption files grouped into one chunk; each chunk is inserted
# into the database inside its own transaction.
BATCH_SIZE=1024
@click.command()
@click.argument(
    'input_dir',
    type=click.Path(exists=True, file_okay=False, path_type=Path),
)
@click.argument(
    'output',
    type=click.Path(dir_okay=False, writable=True, path_type=Path),
)
@click.option('--image-host', help="base URL of images")
@click.option('--explicit/--no-explicit', default=True)
def main(input_dir: Path, output: Path, image_host: str, explicit: bool):
    # Command-line entry point: open (or create) the SQLite database at
    # `output`, load the caption files found in `input_dir` into it, and
    # always close the connection afterwards, even when loading fails.
    #
    # NOTE: documented with comments rather than a docstring on purpose,
    # because click would display a docstring as the command's --help text.
    db = sqlite3.connect(output)
    try:
        _main_with_connection(
            input_dir,
            db,
            image_host,
            explicit,
        )
    finally:
        db.close()
def _main_with_connection(input_dir: Path, connection: sqlite3.Connection, image_host: str = None, explicit: bool = True):
    """
    Load every ``*.txt`` file of a directory into a ``captions`` table.

    Each file becomes one row: the file name without its suffix is the
    ``image_key`` and the file content is the ``caption``. Afterwards a
    full-text search index (``captions_fts``) is (re)built over the table.

    :param input_dir: Directory scanned (non-recursively) for ``*.txt`` files.
    :param connection: Open SQLite connection to write to. It is not closed
        by this function.
    :param image_host: Optional base URL. When given, an ``images`` view is
        created that exposes ``image_host + image_key + '.jpg'`` as ``image``.
        An already existing view is left untouched.
    :param explicit: When false, rows whose caption matches one of the
        explicit-content tags are deleted after loading.
    :raises sqlite3.IntegrityError: If a file's ``image_key`` is already
        present in ``captions`` (the column is the primary key and rows are
        added with a plain ``INSERT``).
    """
    connection.execute("CREATE TABLE IF NOT EXISTS "
                       " captions(image_key text PRIMARY KEY, caption text NOT NULL);")

    if image_host:
        # The host is inlined as a quoted SQL literal (via sql_quote) instead
        # of being passed as a bound parameter, since it has to become part
        # of the stored view definition.
        connection.execute(f"""
            CREATE VIEW IF NOT EXISTS images AS
            SELECT {sql_quote(connection, image_host)} || image_key || '.jpg' AS image,
                   caption,
                   rowid
            FROM captions
        """)

    text_files = input_dir.glob("*.txt")
    with click.progressbar(chunked(text_files, BATCH_SIZE)) as progress:
        for batch in progress:
            # Decode explicitly as UTF-8 so the stored captions do not depend
            # on the platform's default encoding.
            pairs = ((text_file.stem, text_file.read_text(encoding="utf-8"))
                     for text_file in batch)
            # One transaction per batch: commit on success, roll back on error.
            with connection:
                connection.executemany("INSERT INTO captions(image_key, caption) "
                                       "VALUES(?, ?) ", pairs)

    if not explicit:
        # Tags are used as LIKE patterns, so the '%' inside
        # "subreddit:%nsfw" acts as a wildcard.
        ratings = ["rating:unsafe", "rating:explicit", "rating:mature", "meta:nsfw",
                   "subreddit:%nsfw"]
        for rating in ratings:
            with connection:
                c = connection.execute("DELETE FROM captions WHERE caption LIKE ?",
                                       (f"%{rating}%",))
            print(f"Removed {c.rowcount} {rating} rows")

    with connection:
        # Add full-text search index.
        # IF NOT EXISTS keeps this step consistent with the table and view
        # above, so running against a database that already has the index
        # no longer fails here.
        connection.execute("""
            CREATE VIRTUAL TABLE IF NOT EXISTS
            captions_fts USING
            fts5(caption, image_key UNINDEXED, content=captions)
        """)
        # 'rebuild' discards the current index and regenerates it from the
        # content table, so it is correct both for a fresh index and for one
        # left over from an earlier run (no duplicated or stale entries).
        connection.execute("""
            INSERT INTO captions_fts(captions_fts) VALUES('rebuild')
        """)
def chunked(iterable, n):
    """
    Split ``iterable`` into consecutive lists of at most ``n`` items.

    :param iterable: Any iterable; it is consumed lazily, one chunk at a time.
    :param n: Maximum number of items per chunk.
    :return: An iterator of lists that stops at the first empty chunk.
    """
    source = iter(iterable)

    def next_chunk():
        # Every call pulls the next (up to) n items from the shared iterator.
        return take(n, source)

    # Two-argument iter(): keep calling next_chunk until it returns [].
    return iter(next_chunk, [])
def take(n, iterable):
    """
    Return the first ``n`` items of ``iterable`` as a list.

    Fewer items are returned when the iterable runs out early. When
    ``iterable`` is an iterator, the returned items are consumed from it.
    """
    head = islice(iterable, n)
    return [*head]
def sql_quote(connection, value: str) -> str:
    """
    Turn a value into a SQLite string literal, surrounding quotes included.

    The escaping is delegated to SQLite's own ``quote()`` function, which is
    evaluated on the given connection.

    :param connection: SQLite connection used to run ``quote()``
    :param value: String to quote
    :return: The quoted literal, ready to be embedded in a SQL statement
    """
    # Bound parameters (.execute(sql, [params])) are the normal way to escape
    # values, but some places in SQL cannot take them, for example a
    # "... DEFAULT 'value'" clause in a column definition. For those, let
    # SQLite produce the correctly escaped literal itself.
    bindings = {"value": value}
    cursor = connection.execute("SELECT quote(:value)", bindings)
    (quoted,) = cursor.fetchone()
    return quoted
# Run the click command when this file is executed as a script.
if __name__ == "__main__":
    main()