Spaces:
Sleeping
Sleeping
import random | |
import pandas as pd | |
import gradio as gr | |
from typing import Dict, Optional | |
import unibox as ub | |
# Module-level cache of the most recently loaded dataset. It lives in a global
# dict so the loaded data survives between separate Gradio callback calls.
CURRENT_DATASET = {
    "id": None,  # id of the cached dataset (None until the first load)
    "df": None,  # cached pandas DataFrame (None until the first load)
}

# Single-letter rating codes mapped to their spelled-out rating names.
rating_map: Dict[str, str] = {
    "g": "general",
    "s": "sensitive",
    "q": "questionable",
    "e": "explicit",
}
def load_dataset_if_needed(dataset_id: str):
    """
    Make sure CURRENT_DATASET holds the dataset named by dataset_id.

    Nothing happens when the cached id already equals dataset_id. Otherwise
    the dataset is fetched from "hf://<dataset_id>" through unibox, converted
    to a pandas DataFrame, and stored in CURRENT_DATASET together with its id.
    """
    # Cache hit: the requested dataset is the one already in memory.
    if CURRENT_DATASET["id"] == dataset_id:
        return

    fresh_df = ub.loads(f"hf://{dataset_id}").to_pandas()
    # Only reached when the load succeeded, so a failed load leaves the
    # previously cached dataset untouched.
    CURRENT_DATASET.update({"id": dataset_id, "df": fresh_df})
def convert_dbr_tag_string(tag_string: str, shuffle: bool = True) -> str:
    """
    Turn a space-separated tag string into a comma-separated one.

    Underscores inside each tag become spaces and empty fragments (from
    repeated or trailing spaces) are dropped, e.g.
    "1girl long_hair blush" -> "1girl, long hair, blush".
    When shuffle is true the tags are put in random order before joining.
    """
    cleaned_tags = []
    for fragment in tag_string.split(" "):
        if not fragment:
            # Consecutive/trailing spaces produce empty pieces; skip them.
            continue
        cleaned_tags.append(fragment.replace("_", " "))

    if shuffle:
        random.shuffle(cleaned_tags)

    return ", ".join(cleaned_tags)
def _is_missing(value) -> bool:
    """Return True when value is None, NaN/NA, or an empty string."""
    if value is None:
        return True
    if isinstance(value, str):
        return value == ""
    try:
        return bool(pd.isna(value))
    except (TypeError, ValueError):
        # Non-scalar values (e.g. lists) make pd.isna return an array whose
        # truth value is ambiguous; treat such values as present.
        return False


def _text_or_empty(value) -> str:
    """Return value as a string, or "" when it is missing."""
    return "" if _is_missing(value) else str(value)


def get_tags_dict(df_row: pd.Series) -> dict:
    """
    Build the per-category tag strings for one dataset row.

    Reads the columns "rating", "tag_string_artist", "tag_string_character",
    "tag_string_copyright", "tag_string_general", "tag_string_meta" and
    "score" from df_row (a KeyError is raised if one is absent).

    Returns a dict with the keys:
      - "rating_str":    spelled-out rating from rating_map, "" if unknown
      - "artist_str":    the raw artist tag string
      - "character_str": character tags, comma-separated
      - "copyright_str": "copyright:<raw copyright tag string>"
      - "general_str":   general tags, comma-separated
      - "meta_str":      meta tags, comma-separated
      - "score":         the score as a string
    Every entry is "" when the source value is missing (None, NaN or empty).
    """
    rating = df_row["rating"]
    # Normalise the text columns first: a NaN cell is truthy, so a plain
    # `if value` check would let it through and break the tag conversion.
    artist = _text_or_empty(df_row["tag_string_artist"])
    character = _text_or_empty(df_row["tag_string_character"])
    copyright_ = _text_or_empty(df_row["tag_string_copyright"])
    general = _text_or_empty(df_row["tag_string_general"])
    meta = _text_or_empty(df_row["tag_string_meta"])
    score = df_row["score"]

    return {
        "rating_str": rating_map.get(rating, ""),
        "artist_str": artist,
        "character_str": convert_dbr_tag_string(character) if character else "",
        "copyright_str": f"copyright:{copyright_}" if copyright_ else "",
        "general_str": convert_dbr_tag_string(general) if general else "",
        "meta_str": convert_dbr_tag_string(meta) if meta else "",
        # Test for "missing" rather than truthiness so a score of 0 is kept.
        "score": "" if _is_missing(score) else str(score),
    }
def build_tags_from_tags_dict(tags_dict: dict, add_artist_tags: bool = True) -> str:
    """
    Join the tag categories of tags_dict into one comma-separated caption.

    The order is: rating, artist (prefixed with "artist:" and only when
    add_artist_tags is true), character, copyright, general. Categories whose
    string is empty are left out. The "meta_str" and "score" entries are not
    used here.
    """
    artist = tags_dict["artist_str"]
    ordered_parts = [
        tags_dict["rating_str"],
        f"artist:{artist}" if artist and add_artist_tags else "",
        tags_dict["character_str"],
        tags_dict["copyright_str"],
        tags_dict["general_str"],
    ]
    # Drop the empty categories so no stray separators appear.
    return ", ".join(part for part in ordered_parts if part)
def get_captions_for_rows(df, start_idx: int = 0, end_idx: int = 5,
                          tags_front: str = "", tags_back: str = "",
                          add_artist_tags: bool = True) -> list:
    """
    Build one caption per row for the positional slice df.iloc[start_idx:end_idx].

    Each caption is the row's tag string (see build_tags_from_tags_dict),
    with tags_front placed before it and tags_back after it; empty pieces are
    skipped. Returns the captions as a list in row order.
    """
    window = df.iloc[start_idx:end_idx]
    results = []
    for _, record in window.iterrows():
        base = build_tags_from_tags_dict(get_tags_dict(record), add_artist_tags)
        # filter(None, ...) discards whichever of the three pieces is empty.
        results.append(", ".join(filter(None, (tags_front, base, tags_back))))
    return results
def get_previews_for_rows(df: pd.DataFrame, start_idx: int = 0, end_idx: int = 5) -> list:
    """
    Return the "large_file_url" values for rows start_idx..end_idx-1.

    The rows are selected by position (df.iloc semantics), so an end_idx past
    the end of the frame simply yields fewer items. Raises KeyError when the
    frame has no "large_file_url" column.
    """
    # Read the column directly instead of iterating rows: iterrows() would
    # build a full Series for every row just to read one field.
    return df["large_file_url"].iloc[start_idx:end_idx].tolist()
def _coerce_int(value, default: int) -> int:
    """Convert a UI-supplied number (possibly a float or None) to an int."""
    if value is None:
        return default
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # Non-numeric input, NaN or infinity: fall back to the default.
        return default


def gradio_interface(
    dataset_id: str,
    start_idx: int = 0,
    display_count: int = 5,
    tags_front: str = "",
    tags_back: str = "",
    add_artist_tags: bool = True
):
    """
    Gradio callback: produce captions and preview URLs for a dataset window.

    Steps:
      1) Load the dataset if it is not the one currently cached.
      2) Clamp start_idx / display_count to the dataset bounds.
      3) Build the captions table and the preview URL list.
      4) Build a Markdown summary of the parameters used.

    Returns a tuple (captions DataFrame, list of preview URLs, info message).
    On a load failure or an empty dataset the first two items are empty and
    the message describes the problem.
    """
    # 1) Possibly reload. A failed load is reported in the info message
    #    instead of surfacing as an unhandled callback exception.
    try:
        load_dataset_if_needed(dataset_id)
    except Exception as exc:  # UI boundary: report any loader failure
        return pd.DataFrame(), [], f"ERROR: Could not load dataset {dataset_id} ({exc})"
    dset_df = CURRENT_DATASET["df"]
    if dset_df is None:
        return pd.DataFrame(), [], f"ERROR: Could not load dataset {dataset_id}"

    # 2) Figure out total length, clamp inputs
    total_len = len(dset_df)
    if total_len == 0:
        return pd.DataFrame(), [], f"Dataset {dataset_id} is empty."
    # Numeric widgets may deliver floats or None; range() and .iloc slicing
    # need real ints.
    start_idx = max(_coerce_int(start_idx, 0), 0)
    display_count = max(_coerce_int(display_count, 5), 0)
    if start_idx >= total_len:
        start_idx = total_len - 1
    end_idx = min(start_idx + display_count, total_len)

    # 3) Build results
    idxs = range(start_idx, end_idx)
    captions = get_captions_for_rows(dset_df, start_idx, end_idx, tags_front, tags_back, add_artist_tags)
    previews = get_previews_for_rows(dset_df, start_idx, end_idx)
    df_out = pd.DataFrame({"index": idxs, "Captions": captions})

    # 4) Build info string
    info_msg = (
        f"**Current dataset:** {CURRENT_DATASET['id']} \n"
        f"**Dataset length:** {total_len} \n"
        f"**start_idx:** {start_idx}, **display_count:** {display_count}, "
        f"**tags_front:** '{tags_front}', **tags_back:** '{tags_back}', "
        f"**add_artist_tags:** {add_artist_tags}"
    )
    return df_out, previews, info_msg
# Build the UI: input controls on the left, results on the right.
with gr.Blocks() as demo:
    gr.Markdown("## Danbooru2025 Dataset Captions and Previews")
    with gr.Row():
        with gr.Column(scale=1):
            dataset_id_input = gr.Textbox(
                value="dataproc5/test-danbooru2025-tag-balanced-2k",
                label="Dataset ID"
            )
            # precision=0 makes Gradio round the value and pass it as an int;
            # the start index is used for positional slicing and range().
            start_idx_input = gr.Number(value=500, label="Start Index", precision=0)
            display_count_input = gr.Slider(
                value=5, minimum=1, maximum=50, step=1,
                label="Number of Items"
            )
            tags_front_input = gr.Textbox(value="", label="Tags Front")
            tags_back_input = gr.Textbox(value="", label="Tags Back")
            add_artist_tags_input = gr.Checkbox(label="Add artist tags", value=True)
            run_button = gr.Button("Get Captions & Previews")
        with gr.Column(scale=2):
            captions_df_out = gr.DataFrame(label="Captions")
            previews_gallery_out = gr.Gallery(label="Previews", type="filepath")
            info_textbox_out = gr.Markdown(value="")

    # The order of inputs must match gradio_interface's positional parameters,
    # and the order of outputs must match its returned tuple.
    run_button.click(
        fn=gradio_interface,
        inputs=[
            dataset_id_input,
            start_idx_input,
            display_count_input,
            tags_front_input,
            tags_back_input,
            add_artist_tags_input
        ],
        outputs=[
            captions_df_out,
            previews_gallery_out,
            info_textbox_out
        ]
    )

demo.launch()