p / Client /Scripts /scan_existing_exif.py
q6's picture
S
52fadc8
import os
import sqlite3
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Optional, Sequence, Tuple
import numpy as np
from PIL import Image
from tqdm import tqdm
# Project root: the parent of the directory that contains this script.
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
IMAGES_DIR = os.path.join(ROOT_DIR, "images")
# Directory that main() searches for "<post_id>.png" files.
STASH_DIR = os.path.join(IMAGES_DIR, "Stash")
DB_PATH = os.path.join(ROOT_DIR, "db.sqlite")
# Thread-pool size: the CPU count (8 when it cannot be determined), capped at 16.
MAX_WORKERS = min(16, os.cpu_count() or 8)
# Number of leading bytes read from each file when looking for PNG text chunks.
EXIF_METADATA_MAX_BYTES = 512
# Known metadata kinds; a kind's position here fixes its numeric code below.
# NOTE(review): "mj" has a code but determine_exif_type() never returns it.
EXIF_TYPE_ORDER = ("novelai", "sd", "comfy", "mj", "celsys", "photoshop", "stealth")
# Kind name -> 1-based integer code, the value written to pixif_cache.exif_type.
EXIF_TYPE_TO_CODE = {name: idx + 1 for idx, name in enumerate(EXIF_TYPE_ORDER)}
# Byte prefix a file must start with to be treated as a PNG.
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
def open_db(path: str) -> sqlite3.Connection:
    """Open the SQLite database at ``path`` and make sure ``pixif_cache`` exists.

    The table is created when missing and the change is committed, then
    ``ensure_db_schema`` runs so that an older table lacking the
    ``exif_type`` column gains it.

    Args:
        path: Filesystem path handed to ``sqlite3.connect``.

    Returns:
        An open connection; the caller is responsible for closing it.

    Raises:
        Exception: Whatever ``sqlite3`` raised while connecting or preparing
            the schema, propagated unchanged.
    """
    conn = sqlite3.connect(path)
    try:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS pixif_cache (
                post_id TEXT PRIMARY KEY,
                url TEXT,
                exif_type INTEGER
            )
            """
        )
        conn.commit()
        ensure_db_schema(conn)
    except Exception:
        # Do not leak the handle when setup fails (for example when the file
        # is not a valid database); the original error still reaches the caller.
        conn.close()
        raise
    return conn
def ensure_db_schema(conn: sqlite3.Connection) -> None:
    """Add the ``exif_type`` column to ``pixif_cache`` when it is absent.

    Does nothing if the column is already present; otherwise the table is
    altered and the change is committed.
    """
    # Index 1 of each PRAGMA table_info row holds the column name.
    present = {info[1] for info in conn.execute("PRAGMA table_info(pixif_cache)")}
    if "exif_type" in present:
        return
    conn.execute("ALTER TABLE pixif_cache ADD COLUMN exif_type INTEGER")
    conn.commit()
def determine_exif_type(metadata: Optional[bytes]) -> Optional[str]:
    """Map a raw text-chunk payload to a metadata kind label.

    The checks run in a fixed order and the first match wins:

    * exactly ``b"TitleAI generated image"`` -> ``"novelai"``
    * starts with ``b"parameter"`` -> ``"sd"``
    * contains ``b'{"'`` anywhere -> ``"comfy"``
    * starts with ``b"SoftwareCelsys"`` -> ``"celsys"``
    * anything else (including empty bytes) -> ``"photoshop"``

    Args:
        metadata: Payload bytes, or ``None`` when no payload was found.

    Returns:
        The label, or ``None`` only when ``metadata`` is ``None``.
    """
    if metadata is None:
        return None

    if metadata == b"TitleAI generated image":
        label = "novelai"
    elif metadata.startswith(b"parameter"):
        label = "sd"
    elif metadata.find(b'{"') != -1:
        label = "comfy"
    elif metadata.startswith(b"SoftwareCelsys"):
        label = "celsys"
    else:
        # Every other non-None payload falls through to this label.
        label = "photoshop"
    return label
def exif_type_to_code(exif_type: Optional[str]) -> Optional[int]:
    """Translate a metadata kind label into its integer code.

    Returns ``None`` for ``None`` or an empty label, and for any label that
    has no entry in ``EXIF_TYPE_TO_CODE``.
    """
    return EXIF_TYPE_TO_CODE.get(exif_type) if exif_type else None
def parse_png_metadata(data: bytes) -> Optional[bytes]:
    """Return the payload of the first ``tEXt`` or ``iTXt`` chunk in ``data``.

    Scanning starts at offset 8 (the signature itself is not validated here)
    and reads consecutive chunks laid out as a 4-byte big-endian length, a
    4-byte type, the payload, and 4 trailing bytes that are skipped.

    Args:
        data: Leading bytes of a PNG file; may be truncated.

    Returns:
        For ``tEXt`` the payload with every NUL byte removed; for ``iTXt``
        the payload with surrounding whitespace stripped. A payload that
        runs past the end of ``data`` is returned cut short. ``None`` when
        no such chunk has a complete 8-byte header inside ``data``.
    """
    total = len(data)
    cursor = 8
    # A chunk header needs eight readable bytes; stop once fewer remain.
    while cursor + 8 <= total:
        body_size = int.from_bytes(data[cursor:cursor + 4], "big")
        kind = data[cursor + 4:cursor + 8]
        body_start = cursor + 8
        body_end = body_start + body_size
        if kind == b"tEXt":
            return data[body_start:body_end].replace(b"\0", b"")
        if kind == b"iTXt":
            return data[body_start:body_end].strip()
        # Step over the payload plus the four bytes that follow it.
        cursor = body_end + 4
    return None
def parse_png_metadata_file(path: str) -> Optional[bytes]:
    """Read the head of the file at ``path`` and extract its first text chunk.

    Only the first ``EXIF_METADATA_MAX_BYTES`` bytes are read. Best effort:
    any exception (missing file, unreadable file, ...) yields ``None``.

    Returns:
        The result of ``parse_png_metadata`` on the bytes read, or ``None``
        when the file does not start with ``PNG_SIGNATURE`` or cannot be read.
    """
    try:
        with open(path, "rb") as stream:
            prefix = stream.read(EXIF_METADATA_MAX_BYTES)
        if prefix.startswith(PNG_SIGNATURE):
            return parse_png_metadata(prefix)
        return None
    except Exception:
        return None
def byteize(alpha: np.ndarray) -> np.ndarray:
    """Pack the least-significant bit of each element of ``alpha`` into bytes.

    Elements are taken from the flattened transpose of ``alpha``, so a 2-D
    input is walked column by column. Trailing elements that do not fill a
    whole group of eight are dropped.

    Returns:
        An array with one packed byte per row (shape ``(n_bytes, 1)``).
    """
    flat = alpha.T.ravel()
    usable = (flat.shape[0] // 8) * 8
    low_bits = np.bitwise_and(flat[:usable], 1)
    return np.packbits(low_bits.reshape((-1, 8)), axis=1)
class LSBExtractor:
    """Sequential reader over the bytes that ``byteize`` packs from an array."""

    def __init__(self, alpha: np.ndarray) -> None:
        """Pack ``alpha`` with ``byteize`` and start reading at offset zero."""
        self.pos = 0
        self.data = byteize(alpha)

    def get_next_n_bytes(self, n: int) -> bytearray:
        """Return up to ``n`` bytes from the current position.

        The position always advances by ``n``, even when fewer bytes remain,
        in which case the returned buffer is shorter than requested.
        """
        start = self.pos
        self.pos = start + n
        return bytearray(self.data[start:self.pos])

    def read_32bit_integer(self) -> Optional[int]:
        """Read four bytes as a big-endian integer.

        Returns ``None`` when fewer than four bytes were available.
        """
        raw = self.get_next_n_bytes(4)
        if len(raw) != 4:
            return None
        return int.from_bytes(raw, byteorder="big")
def extract_stealth_metadata(image: Image.Image) -> bool:
    """Check the alpha channel's low bits for a ``stealth_pngcomp`` header.

    The low bit of every alpha value is packed into bytes (via
    ``LSBExtractor``); those bytes must start with the UTF-8 text
    ``stealth_pngcomp`` followed by a 4-byte length field. The length value
    itself is not used.

    Returns:
        ``True`` when both the marker and the length field are present.

    Raises:
        AssertionError: ``"image format"`` if the image has no ``A`` band,
            ``"magic number"`` if the marker differs, ``"length missing"``
            if fewer than four bytes follow the marker.
        UnicodeDecodeError: If the leading bytes are not valid UTF-8.
    """
    if "A" not in image.getbands():
        raise AssertionError("image format")

    reader = LSBExtractor(np.array(image.getchannel("A")))

    expected_magic = "stealth_pngcomp"
    found_magic = reader.get_next_n_bytes(len(expected_magic)).decode("utf-8")
    if found_magic != expected_magic:
        raise AssertionError("magic number")

    if reader.read_32bit_integer() is None:
        raise AssertionError("length missing")
    return True
def has_stealth_png_path(path: str) -> bool:
    """Report whether the image at ``path`` passes ``extract_stealth_metadata``.

    Best effort: any exception while opening or inspecting the image is
    treated as "not present" and yields ``False``.
    """
    try:
        with Image.open(path) as picture:
            found = extract_stealth_metadata(picture)
    except Exception:
        return False
    return found
def detect_exif_code_from_path(path: str) -> Optional[int]:
    """Determine the metadata code for the image file at ``path``.

    The text chunk in the file head is tried first; only when that yields no
    code is the image opened to look for the alpha-channel marker.

    Returns:
        The integer code, or ``None`` when neither check finds anything.
    """
    text_code = exif_type_to_code(determine_exif_type(parse_png_metadata_file(path)))
    if text_code is not None:
        return text_code
    # Fallback check; it opens and decodes the image.
    return EXIF_TYPE_TO_CODE.get("stealth") if has_stealth_png_path(path) else None
def fetch_pending_post_ids(conn: sqlite3.Connection) -> List[str]:
    """List the ``post_id`` of every row that still needs an ``exif_type``.

    A row qualifies when ``exif_type`` is NULL and ``url`` is neither NULL
    nor empty. Each id is returned as a ``str``.
    """
    query = """
        SELECT post_id
        FROM pixif_cache
        WHERE exif_type IS NULL
          AND COALESCE(url, '') != ''
    """
    return [str(record[0]) for record in conn.execute(query).fetchall()]
def update_exif_types(conn: sqlite3.Connection, rows: Sequence[Tuple[int, str]]) -> None:
    """Write ``exif_type`` codes back to ``pixif_cache``.

    Args:
        conn: Open connection; this function does not commit.
        rows: ``(exif_type, post_id)`` pairs. An empty sequence is a no-op.
    """
    if rows:
        statement = """
            UPDATE pixif_cache SET exif_type = ?
            WHERE post_id = ?
        """
        conn.executemany(statement, rows)
def detect_exif_codes_from_files(
    post_ids: Sequence[str],
    stash_dir: str,
    max_workers: int = MAX_WORKERS,
) -> Dict[str, Optional[int]]:
    """Scan ``<stash_dir>/<post_id>.png`` for every id using a thread pool.

    Progress is reported through a ``tqdm`` bar as scans complete.

    Args:
        post_ids: Ids whose image files should be scanned.
        stash_dir: Directory containing the image files.
        max_workers: Thread-pool size.

    Returns:
        A mapping from post id to the detected code; the value is ``None``
        when nothing was detected or the scan raised.
    """
    codes: Dict[str, Optional[int]] = {}
    if not post_ids:
        return codes

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Future -> post id, so completed work can be matched to its id.
        pending = {}
        for post_id in post_ids:
            image_path = os.path.join(stash_dir, f"{post_id}.png")
            pending[pool.submit(detect_exif_code_from_path, image_path)] = post_id

        with tqdm(total=len(pending), unit="image", desc="Scanning exif") as progress:
            for finished in as_completed(pending):
                owner = pending[finished]
                try:
                    codes[owner] = finished.result()
                except Exception:
                    codes[owner] = None
                progress.update(1)
    return codes
def main() -> int:
    """Scan stashed images for pending rows and store the detected codes.

    Makes sure the stash directory exists, opens the database, collects rows
    without an ``exif_type``, scans the ones that have a matching PNG in the
    stash, and writes every detected code back in a single transaction.

    Returns:
        Always ``0``.
    """
    os.makedirs(STASH_DIR, exist_ok=True)
    conn = open_db(DB_PATH)
    try:
        pending_ids = fetch_pending_post_ids(conn)
        if not pending_ids:
            print("No pending rows.")
            return 0

        on_disk = [
            post_id
            for post_id in pending_ids
            if os.path.exists(os.path.join(STASH_DIR, f"{post_id}.png"))
        ]
        if not on_disk:
            print("No matching images in stash.")
            return 0

        detected = detect_exif_codes_from_files(on_disk, STASH_DIR)
        # Rows with no detected code stay NULL so a later run retries them.
        updates = [
            (code, post_id)
            for post_id, code in detected.items()
            if code is not None
        ]
        if updates:
            # The connection context manager commits on success and rolls
            # back if the update raises.
            with conn:
                update_exif_types(conn, updates)
        print(f"Updated {len(updates)} rows.")
        return 0
    finally:
        conn.close()
# Run the scan only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())