Spaces:
Running
Running
import sqlite3 | |
from functools import partial | |
from itertools import islice | |
from pathlib import Path | |
import click | |
# Number of caption files read and inserted per transaction; each batch is
# also one step of the import progress bar.
BATCH_SIZE = 1 << 10  # 1024
def main(input_dir: Path, output: Path, image_host: str, explicit: bool):
    """Build the caption database file *output* from the captions in *input_dir*.

    Thin wrapper around ``_main_with_connection`` that owns the SQLite
    connection: it is opened on *output* and closed again whether or not the
    build succeeds.

    :param input_dir: Directory holding the ``*.txt`` caption files.
    :param output: Path of the SQLite database file to create or extend.
    :param image_host: String prepended to ``<image_key>.jpg`` in the optional
        ``images`` view (presumably a URL prefix); falsy to skip the view.
    :param explicit: When false, captions carrying explicit/NSFW tags are
        deleted after loading.
    """
    db = sqlite3.connect(output)
    try:
        _main_with_connection(
            input_dir,
            db,
            image_host=image_host,
            explicit=explicit,
        )
    finally:
        # Release the database handle even if the build raised.
        db.close()
def _main_with_connection(
    input_dir: Path,
    connection: sqlite3.Connection,
    image_host: "str | None" = None,
    explicit: bool = True,
):
    """Load every ``*.txt`` caption file in *input_dir* into *connection*.

    Steps:

    1. Create the ``captions(image_key, caption)`` table and, when
       *image_host* is given, an ``images`` view, unless they already exist.
    2. Insert one row per text file: the file stem becomes ``image_key`` and
       the file contents become ``caption``. Files are processed
       ``BATCH_SIZE`` at a time, one transaction per batch, behind a click
       progress bar.
    3. Unless *explicit* is true, delete captions carrying any of the
       rating/NSFW tags listed in ``_delete_explicit_captions``.
    4. Create (if needed) and rebuild the ``captions_fts`` FTS5 index.

    :param input_dir: Directory containing the ``*.txt`` caption files
        (searched non-recursively).
    :param connection: Open SQLite connection to write into.
    :param image_host: Optional prefix for the ``images`` view; each row's
        ``image`` column is ``<image_host><image_key>.jpg``.
    :param explicit: Keep explicit captions when true (the default); delete
        them when false.
    :raises sqlite3.IntegrityError: If a file's stem is already present in
        ``captions`` (``image_key`` is the primary key).
    """
    _create_schema(connection, image_host)
    _insert_captions(input_dir, connection)
    if not explicit:
        _delete_explicit_captions(connection)
    _build_fts_index(connection)


def _create_schema(connection: sqlite3.Connection, image_host: "str | None") -> None:
    """Create the ``captions`` table and, if *image_host* is truthy, the ``images`` view."""
    connection.execute(
        "CREATE TABLE IF NOT EXISTS "
        " captions(image_key text PRIMARY KEY, caption text NOT NULL);"
    )
    if image_host:
        # DDL cannot take bound parameters, so the prefix is inlined as an
        # SQL string literal escaped by SQLite's own quote() function.
        connection.execute(f"""
            CREATE VIEW IF NOT EXISTS images AS
            SELECT {sql_quote(connection, image_host)} || image_key || '.jpg' AS image,
                   caption,
                   rowid
            FROM captions
        """)


def _insert_captions(input_dir: Path, connection: sqlite3.Connection) -> None:
    """Insert one ``captions`` row per ``*.txt`` file, ``BATCH_SIZE`` files per transaction."""
    text_files = input_dir.glob("*.txt")
    with click.progressbar(chunked(text_files, BATCH_SIZE)) as progress:
        for batch in progress:
            # Key is the file name without ".txt"; caption is the file body.
            # Decode as UTF-8 explicitly rather than via the platform locale.
            pairs = [
                (text_file.stem, text_file.read_text(encoding="utf-8"))
                for text_file in batch
            ]
            # Connection context manager: commit the batch, or roll back on error.
            with connection:
                connection.executemany(
                    "INSERT INTO captions(image_key, caption) VALUES(?, ?) ",
                    pairs,
                )


def _delete_explicit_captions(connection: sqlite3.Connection) -> None:
    """Delete captions containing any rating/NSFW tag, reporting the count per tag."""
    # Each tag is matched as a substring with SQL LIKE. The "%" inside
    # "subreddit:%nsfw" is itself a wildcard, so it matches "subreddit:"
    # followed anywhere later by "nsfw".
    ratings = ["rating:unsafe", "rating:explicit", "rating:mature", "meta:nsfw",
               "subreddit:%nsfw"]
    for rating in ratings:
        with connection:
            cursor = connection.execute(
                "DELETE FROM captions WHERE caption LIKE ?",
                (f"%{rating}%",),
            )
            print(f"Removed {cursor.rowcount} {rating} rows")


def _build_fts_index(connection: sqlite3.Connection) -> None:
    """Create (if needed) and fully rebuild the ``captions_fts`` FTS5 index.

    ``captions_fts`` is an external-content FTS5 table over ``captions``
    (``content=captions``): it reads column values back from ``captions`` by
    rowid. ``image_key`` is declared UNINDEXED, so it is not added to the
    full-text index.
    """
    with connection:
        connection.execute("""
            CREATE VIRTUAL TABLE IF NOT EXISTS captions_fts
            USING fts5(caption, image_key UNINDEXED, content=captions)
        """)
        # 'rebuild' discards the index and re-derives it from the content
        # table. Unlike INSERT ... SELECT it can safely be repeated, so
        # re-running against an existing database neither fails nor
        # duplicates index entries, and rows deleted above drop out.
        connection.execute(
            "INSERT INTO captions_fts(captions_fts) VALUES('rebuild')"
        )
def chunked(iterable, n):
    """Split *iterable* into successive lists of at most *n* items.

    The last list may be shorter than *n*; an empty list is never produced.
    The input is consumed lazily, so generators such as ``Path.glob`` work.

    :param iterable: Any iterable.
    :param n: Maximum batch size; must be at least 1.
    :raises ValueError: If *n* is less than 1. With ``n == 0`` every batch
        would be empty, which would silently discard all input.
    :return: An iterator of lists.
    """
    if n < 1:
        raise ValueError(f"chunk size must be at least 1, got {n!r}")
    source = iter(iterable)
    # iter(callable, sentinel) keeps calling take() until it returns the
    # sentinel [], i.e. until the source iterator is exhausted.
    return iter(partial(take, n, source), [])


def take(n, iterable):
    """Return the next (at most) *n* items of *iterable* as a list."""
    return list(islice(iterable, n))
def sql_quote(connection, value: str) -> str:
    """Return *value* as a single-quoted SQLite string literal.

    The escaping is delegated to SQLite's built-in ``quote()`` function, so
    embedded single quotes are doubled exactly as SQLite expects. Intended
    for places where a bound parameter cannot be used, such as DDL (a
    ``CREATE VIEW`` body or a ``... DEFAULT 'value'`` column clause).

    :param connection: Open SQLite connection used to evaluate ``quote()``.
    :param value: String to quote.
    :return: The quoted literal, including the surrounding single quotes.
    """
    cursor = connection.execute("SELECT quote(:value)", {"value": value})
    row = cursor.fetchone()
    return row[0]
# main() takes four required arguments, so calling it bare raises TypeError.
# This click command parses them from the command line and forwards them,
# leaving main() itself a plain, directly callable function.
@click.command()
@click.argument("input_dir", type=click.Path(exists=True, file_okay=False))
@click.argument("output", type=click.Path(dir_okay=False))
@click.option(
    "--image-host",
    default=None,
    help="Prefix for an `images` view exposing <prefix><image_key>.jpg.",
)
@click.option(
    "--explicit/--no-explicit",
    default=True,
    show_default=True,
    help="Keep captions tagged as explicit/NSFW; --no-explicit deletes them.",
)
def cli(input_dir, output, image_host, explicit):
    """Build the SQLite caption database OUTPUT from the .txt files in INPUT_DIR."""
    main(Path(input_dir), Path(output), image_host, explicit)


if __name__ == "__main__":
    cli()