# booru-images / app.py
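# Streamlit app for browsing booru image datasets stored as Parquet files on
# the Hugging Face Hub, with filtering by aesthetic scores, post ID, rating,
# and tags.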
import streamlit as st
import pandas as pd
import pyarrow.parquet as pq
import pyarrow.dataset as ds
import time
import os
import plotly.graph_objects as go
import gc
import ast
import numpy as np
from huggingface_hub import hf_hub_download, login
st.set_page_config(layout="wide")
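# Configuration comes from environment variables: HF_TOKEN / HF_REPO for Hub
# access, and PARQUET_FILE1-3 for the per-source dataset filenames used below.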
hf_token = os.getenv('HF_TOKEN')
hf_repo = os.getenv('HF_REPO')
if hf_token:  # skip authentication when no token is configured
    login(token=hf_token)
data_source = st.sidebar.radio("Source", ["Danbooru", "Gelbooru", "Rule 34"], index=0)
if data_source == "Danbooru":
    parquet_file = hf_hub_download(repo_id=hf_repo, filename=os.getenv('PARQUET_FILE1'), repo_type="dataset")
elif data_source == "Gelbooru":
    parquet_file = hf_hub_download(repo_id=hf_repo, filename=os.getenv('PARQUET_FILE2'), repo_type="dataset")
elif data_source == "Rule 34":
    parquet_file = hf_hub_download(repo_id=hf_repo, filename=os.getenv('PARQUET_FILE3'), repo_type="dataset")
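# Cached once per Parquet file: inspects footer metadata and a handful of row
# groups to derive the row count, post_id bounds, and rating values, without
# reading the full (potentially large) file.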
@st.cache_resource
def load_parquet_metadata(parquet_file):
    try:
        parquet_dataset = pq.ParquetFile(parquet_file)
        metadata = parquet_dataset.metadata
        num_rows = metadata.num_rows
        sample_df = next(parquet_dataset.iter_batches(batch_size=10)).to_pandas()
        if 'post_id' in sample_df.columns:
            try:
                # Preferred path: derive the post_id range from the row-group
                # statistics stored in the file footer (no data pages read).
                min_post_id = float('inf')
                max_post_id = float('-inf')
                for i in range(parquet_dataset.metadata.num_row_groups):
                    row_group = parquet_dataset.metadata.row_group(i)
                    for j in range(row_group.num_columns):
                        col = row_group.column(j)
                        if col.path_in_schema == 'post_id':
                            stats = col.statistics
                            if stats is not None:
                                min_post_id = min(min_post_id, stats.min)
                                max_post_id = max(max_post_id, stats.max)
                if min_post_id == float('inf') or max_post_id == float('-inf'):
                    raise ValueError("Invalid post_id range")
            except Exception as e:
                st.warning(f"Unable to get post_id range from statistics: {str(e)}")
                # Fallback: approximate the range from the first batch plus the
                # first, middle, and last row groups.
                min_post_id = float('inf')
                max_post_id = float('-inf')
                with pq.ParquetFile(parquet_file) as reader:
                    first_batch = next(reader.iter_batches(batch_size=1000))
                    first_df = first_batch.to_pandas()
                    batch_min = first_df['post_id'].min()
                    batch_max = first_df['post_id'].max()
                    min_post_id = min(min_post_id, batch_min)
                    max_post_id = max(max_post_id, batch_max)
                    num_row_groups = reader.num_row_groups
                    sample_indices = [0, num_row_groups // 2, num_row_groups - 1]
                    for idx in sample_indices:
                        if 0 <= idx < num_row_groups:
                            batch = reader.read_row_group(idx).to_pandas()
                            batch_min = batch['post_id'].min()
                            batch_max = batch['post_id'].max()
                            min_post_id = min(min_post_id, batch_min)
                            max_post_id = max(max_post_id, batch_max)
        else:
            min_post_id = 0
            max_post_id = 100000
        available_ratings = []
        if 'rating' in sample_df.columns:
            # Collect the distinct rating values from up to three row groups.
            ratings_set = set()
            for i in range(min(3, parquet_dataset.num_row_groups)):
                sample = parquet_dataset.read_row_group(i, columns=['rating']).to_pandas()
                ratings_set.update(sample['rating'].unique())
            available_ratings = sorted(list(ratings_set))
        else:
            available_ratings = ['general']
        print(f"Metadata loaded: {num_rows} rows, post_id range: {min_post_id}-{max_post_id}")
        return {
            'num_rows': num_rows,
            'min_post_id': int(min_post_id),
            'max_post_id': int(max_post_id),
            'available_ratings': available_ratings,
            'columns': sample_df.columns.tolist()
        }
    except Exception as e:
        st.error(f"Error loading Parquet metadata: {str(e)}")
        return {
            'num_rows': 0,
            'min_post_id': 0,
            'max_post_id': 100000,
            'available_ratings': ['general'],
            'columns': []
        }
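# Scans the dataset with numeric-range and rating predicates pushed down to
# Arrow, so row groups whose statistics miss the filter are skipped on read.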
def get_filtered_batch(parquet_file, filters, needed_columns, sort_option):
    try:
        dataset = ds.dataset(parquet_file, format='parquet')
        pa_filters = []
        for col, op, val in filters:
            if col in ['post_id', 'ava_score', 'aesthetic_score']:
                if op == '>=':
                    pa_filters.append(ds.field(col) >= val)
                elif op == '<=':
                    pa_filters.append(ds.field(col) <= val)
            elif op == 'in' and len(val) > 0:
                # Expand the rating membership test into an OR of equalities.
                rating_filters = [ds.field(col) == r for r in val]
                or_expr = rating_filters[0]
                for rf in rating_filters[1:]:
                    or_expr = or_expr | rf
                pa_filters.append(or_expr)
        # AND all predicates together; None means "no filter".
        final_filter = None
        if pa_filters:
            final_filter = pa_filters[0]
            for f in pa_filters[1:]:
                final_filter = final_filter & f
        scanner = dataset.scanner(columns=needed_columns, filter=final_filter)
        df = scanner.to_table().to_pandas()
        df.set_index('post_id', inplace=True)
        if sort_option == "Post ID (Descending)":
            df = df.sort_values(by=df.index.name, ascending=False)
        elif sort_option == "Post ID (Ascending)":
            df = df.sort_values(by=df.index.name, ascending=True)
        elif sort_option == "AVA Score":
            df = df.sort_values(by='ava_score', ascending=False)
        elif sort_option == "Aesthetic Score":
            df = df.sort_values(by='aesthetic_score', ascending=False)
        return df
    except Exception as e:
        st.error(f"Error reading batch: {str(e)}")
        return pd.DataFrame()
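# Tag filtering happens in pandas after the Arrow scan, since 'tags' cells are
# lists/arrays and cannot be expressed as a simple dataset predicate.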
def process_tags_for_filtering(df, selected_tags, undesired_tags):
    if not selected_tags and not undesired_tags:
        return df

    def as_tag_set(cell):
        # Normalize a 'tags' cell (list, numpy array, scalar, or empty) to a set.
        if isinstance(cell, list):
            return set(cell)
        if isinstance(cell, np.ndarray):
            return set(cell.tolist()) if cell.size > 0 else set()
        if isinstance(cell, np.generic):
            return {cell.item()}
        if cell:
            return {cell}
        return set()

    mask = np.ones(len(df), dtype=bool)
    if selected_tags:
        # Keep only rows containing every required tag.
        for i, cell in enumerate(df['tags']):
            if mask[i] and not selected_tags.issubset(as_tag_set(cell)):
                mask[i] = False
    if undesired_tags:
        # Drop rows containing any excluded tag.
        for i, cell in enumerate(df['tags']):
            if mask[i] and undesired_tags.intersection(as_tag_set(cell)):
                mask[i] = False
    return df[mask]
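# Cached wrapper: arguments are passed as repr() strings so Streamlit can hash
# them for the cache key; they are parsed back with ast.literal_eval.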
@st.cache_data(ttl=600)
def get_filtered_data(parquet_file, filters_str, sort_option, selected_tags_str, undesired_tags_str):
    filters = ast.literal_eval(filters_str)
    selected_tags = set(ast.literal_eval(selected_tags_str))
    undesired_tags = set(ast.literal_eval(undesired_tags_str))
    needed_columns = ['post_id', 'tags', 'ava_score', 'aesthetic_score', 'rating', 'large_file_url']
    df = get_filtered_batch(parquet_file, filters, needed_columns, sort_option)
    if selected_tags or undesired_tags:
        df = process_tags_for_filtering(df, selected_tags, undesired_tags)
    return df
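# Sidebar filter widgets: the score sliders cover the 0-10 model range, and
# the post_id slider bounds come from the cached Parquet metadata.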
st.title(f'{data_source} Images')
metadata = load_parquet_metadata(parquet_file)
score_range = st.sidebar.slider('Select AVA Score range', min_value=0.0, max_value=10.0, value=(5.0, 10.0), step=0.1)
score_range_v2 = st.sidebar.slider('Select Aesthetic Score range', min_value=0.0, max_value=10.0, value=(9.0, 10.0), step=0.1)
min_post_id = metadata['min_post_id']
max_post_id = metadata['max_post_id']
post_id_range = st.sidebar.slider('Select Post ID range',
                                  min_value=min_post_id,
                                  max_value=max_post_id,
                                  value=(min_post_id, max_post_id),
                                  step=1000)
available_ratings = metadata['available_ratings']
selected_ratings = st.sidebar.multiselect(
    'Select ratings to include',
    options=available_ratings,
    default=[],
    help='Filter images by their rating category'
)
page_number = st.sidebar.number_input('Page', min_value=1, value=1, step=1)
items_per_page = 50
sort_option = st.sidebar.selectbox('Sort by', options=['Post ID (Descending)', 'Post ID (Ascending)', 'AVA Score', 'Aesthetic Score'], index=0)
user_input_tags = st.text_input('Enter tags (space-separated)', value='1girl scenery', help='Filter images based on tags. Use "-" to exclude tags.')
tag_tokens = user_input_tags.split()
selected_tags = {tag for tag in tag_tokens if not tag.startswith('-')}
undesired_tags = {tag[1:] for tag in tag_tokens if tag.startswith('-') and len(tag) > 1}
filters = [
    ('ava_score', '>=', score_range[0]),
    ('ava_score', '<=', score_range[1]),
    ('aesthetic_score', '>=', score_range_v2[0]),
    ('aesthetic_score', '<=', score_range_v2[1]),
    ('post_id', '>=', post_id_range[0]),
    ('post_id', '<=', post_id_range[1]),
]
if selected_ratings:
    filters.append(('rating', 'in', selected_ratings))
filters_str = repr(filters)
selected_tags_str = repr(list(selected_tags))
undesired_tags_str = repr(list(undesired_tags))
start_time = time.time()
current_batch = get_filtered_data(
    parquet_file, filters_str, sort_option,
    selected_tags_str, undesired_tags_str
)
print(f"Data retrieved in {time.time() - start_time:.2f} seconds")
batch_start = (page_number - 1) * items_per_page
end_idx = min(batch_start + items_per_page, len(current_batch))
current_data = current_batch.iloc[batch_start:end_idx] if batch_start < len(current_batch) else pd.DataFrame()
st.sidebar.write(f"Images on this page: {len(current_data)}")
st.sidebar.write(f"Total filtered sample: {len(current_batch)}")
columns_per_row = 5
rows = [current_data.iloc[i:i + columns_per_row] for i in range(0, len(current_data), columns_per_row)]
for row in rows:
    cols = st.columns(columns_per_row)
    for col, (_, row_data) in zip(cols, row.iterrows()):
        with col:
            post_id = row_data.name
            if data_source == "Danbooru":
                link = f"https://danbooru.donmai.us/posts/{post_id}"
            elif data_source == "Gelbooru":
                link = f"https://gelbooru.com/index.php?page=post&s=view&id={post_id}"
            elif data_source == "Rule 34":
                link = f"https://rule34.xxx/index.php?page=post&s=view&id={post_id}"
            st.image(row_data['large_file_url'], caption=f"ID: {post_id}, AVA: {row_data['ava_score']:.2f}, Aesthetic: {row_data['aesthetic_score']:.2f}\n{link}", use_container_width=True)
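# Sidebar histogram of both score columns, subsampled to at most 5000 evenly
# spaced rows so plotting stays cheap on large result sets.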
def histogram_slider(df, column1, column2):
    if df.empty:
        return
    sample_size = min(5000, len(df))
    if len(df) > sample_size:
        step = len(df) // sample_size
        indices = np.arange(0, len(df), step)[:sample_size]
        sample_data = df.iloc[indices]
    else:
        sample_data = df
    hist1, bin_edges1 = np.histogram(sample_data[column1].dropna(), bins=30)
    hist2, bin_edges2 = np.histogram(sample_data[column2].dropna(), bins=30)
    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=(bin_edges1[:-1] + bin_edges1[1:]) / 2,
        y=hist1,
        name=column1,
        opacity=0.75,
        width=(bin_edges1[1] - bin_edges1[0])
    ))
    fig.add_trace(go.Bar(
        x=(bin_edges2[:-1] + bin_edges2[1:]) / 2,
        y=hist2,
        name=column2,
        opacity=0.75,
        width=(bin_edges2[1] - bin_edges2[0])
    ))
    fig.update_layout(
        barmode='overlay',
        bargap=0.1,
        height=200,
        margin=dict(l=0, r=0, t=0, b=0),
        legend=dict(orientation='h', yanchor='bottom', y=-0.4, xanchor='center', x=0.5),
    )
    st.sidebar.plotly_chart(fig, use_container_width=True, config={'displayModeBar': False})
    del sample_data, hist1, hist2, bin_edges1, bin_edges2
    gc.collect()
if not current_batch.empty:
    start_time = time.time()
    histogram_slider(current_batch, 'ava_score', 'aesthetic_score')
    print(f"Histogram displayed: {time.time() - start_time:.2f} seconds")