# Hugging Face Space: purplesmartai/aesthetic-classifier (demo app)
| import os | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import clip | |
| import gradio as gr | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
# --- Runtime setup: pick device, load the CLIP backbone and the scoring head ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[info] device: {DEVICE}")

print("[info] Loading CLIP ViT-L/14 ...")
clip_model, preprocess = clip.load("ViT-L/14", device=DEVICE)
clip_model.eval()

print("[info] Downloading aesthetic-classifier checkpoint ...")
ckpt_path = hf_hub_download(
    repo_id="purplesmartai/aesthetic-classifier",
    filename="v2.ckpt",
)
# NOTE(review): torch.load unpickles arbitrary objects. The checkpoint comes
# from a pinned HF repo, but consider weights_only=True if the ckpt allows it.
checkpoint_data = torch.load(ckpt_path, map_location=DEVICE)
state_dict = checkpoint_data["state_dict"]
# Strip only the *leading* "model." that the training wrapper (e.g. Lightning)
# prepends to each parameter name. The previous str.replace() would also have
# mangled any interior occurrence of "model." inside a key.
state_dict = {
    (k[len("model."):] if k.startswith("model.") else k): v
    for k, v in state_dict.items()
}

# MLP regression head: 768-d CLIP ViT-L/14 image embedding -> scalar raw score.
aesthetic_model = nn.Sequential(
    nn.Linear(768, 1024),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, 1),
).to(DEVICE)
aesthetic_model.load_state_dict(state_dict)
aesthetic_model.eval()
print("[info] Model ready.")
def get_score(image: Image.Image) -> float:
    """Return the raw aesthetic score for a PIL image.

    Embeds the image with CLIP ViT-L/14, L2-normalises the embedding,
    and feeds it through the regression head.

    Args:
        image: any PIL image; converted to RGB internally.

    Returns:
        The scalar model output (roughly 0..1 — see raw_to_pony's clamp).
    """
    tensor = preprocess(image.convert("RGB")).unsqueeze(0).to(DEVICE)
    # no_grad: pure inference — without it every call builds an autograd
    # graph and leaks activation memory.
    with torch.no_grad():
        feat = clip_model.encode_image(tensor).cpu().numpy().astype("float32")
        # L2-normalise the embedding; guard against a zero vector so we
        # never divide by zero.
        norm = np.linalg.norm(feat, axis=1, keepdims=True)
        feat = feat / np.where(norm == 0, 1, norm)
        return aesthetic_model(torch.tensor(feat, device=DEVICE)).item()
def raw_to_pony(raw: float) -> int:
    """Bucket a raw model output into the integer pony score 0-9.

    The value is clamped to [0.0, 0.99] before scaling by 10, so the
    result always lands in the valid score_0..score_9 range.
    """
    clamped = min(0.99, raw)
    if clamped < 0.0:
        clamped = 0.0
    return int(clamped * 10)
# Ten-step colour ramp indexed by pony score 0-9: reds for low scores,
# through orange/yellow, to greens for high scores.
COLOURS = [
    "#c0392b", "#e74c3c", "#e67e22", "#f39c12", "#d4ac0d",
    "#27ae60", "#1e8449", "#148f77", "#0e6655", "#0a4f42",
]
def build_html(raw: float) -> str:
    """Render the score card: badge, a 2x5 score_0..score_9 grid, and a raw bar."""

    def cell(idx: int, selected: int) -> str:
        # One tile of the grid. The selected tile gets its ramp colour,
        # full opacity and a slight scale-up; the rest are dimmed.
        hit = idx == selected
        bg = COLOURS[idx] if hit else "rgba(255,255,255,0.07)"
        border = f"2px solid {COLOURS[idx]}" if hit else "2px solid rgba(255,255,255,0.12)"
        weight = "700" if hit else "400"
        scale = "scale(1.08)" if hit else "scale(1)"
        opac = "1" if hit else "0.5"
        return (
            f'<div style="background:{bg};border:{border};border-radius:8px;'
            f'padding:9px 4px;text-align:center;font-size:.78rem;font-weight:{weight};'
            f'color:#fff;transform:{scale};opacity:{opac};transition:all .2s;'
            f'user-select:none;white-space:nowrap;">'
            f"score_{idx}</div>"
        )

    pony = raw_to_pony(raw)
    colour = COLOURS[pony]
    # Two rows of 5 so the grid never overflows
    rows = [
        f'<div style="display:grid;grid-template-columns:repeat(5,1fr);gap:5px;margin-bottom:5px;">'
        + "".join(cell(i, pony) for i in range(start, start + 5))
        + "</div>"
        for start in (0, 5)
    ]
    # Progress bar width mirrors the clamped raw value as a percentage.
    bar_w = min(max(raw, 0.0), 1.0) * 100
    return f"""
<div style="font-family:'Inter',sans-serif;padding:8px 0;">
<div style="text-align:center;margin-bottom:18px;">
<div style="display:inline-block;background:{colour};color:#fff;border-radius:12px;
padding:12px 32px;font-size:1.9rem;font-weight:800;letter-spacing:.04em;
box-shadow:0 4px 20px {colour}66;">score_{pony}</div>
<div style="color:#aaa;font-size:.82rem;margin-top:7px;">
raw: <code style="color:#ddd">{raw:.4f}</code>
</div>
</div>
{"".join(rows)}
<div style="background:rgba(255,255,255,.1);border-radius:6px;height:7px;overflow:hidden;margin-top:8px;">
<div style="width:{bar_w:.1f}%;height:100%;
background:linear-gradient(90deg,#c0392b,#f39c12,#27ae60);
border-radius:6px;"></div>
</div>
<div style="display:flex;justify-content:space-between;font-size:.7rem;color:#666;margin-top:4px;">
<span>score_0</span><span>score_9</span>
</div>
</div>"""
def classify(image):
    """Gradio callback: score *image*, or return a placeholder when no image is set."""
    if image is not None:
        return build_html(get_score(image))
    return "<p style='color:#888;text-align:center;padding:40px 0'>Upload an image to score it.</p>"
# --- Gradio UI: layout is declared inside the Blocks context, so statement
# order determines on-screen order. ---
with gr.Blocks(
    title="Aesthetic Classifier — PurpleSmartAI",
    theme=gr.themes.Soft(primary_hue="purple"),
    # Centre the app and style the #title / #sub elements referenced below.
    css=".gradio-container{max-width:860px!important;margin:auto}"
    " #title{text-align:center} #sub{text-align:center;color:#888;font-size:.9rem;margin-bottom:1.4rem}",
) as demo:
    gr.Markdown("# 🎨 Aesthetic Classifier", elem_id="title")
    gr.Markdown(
        "CLIP ViT-L/14 regression model by **PurpleSmartAI** for Pony V7 captioning. "
        "Outputs a **score_0…score_9** tag used directly in training captions.",
        elem_id="sub",
    )
    # Left column: input image + trigger button. Right column: rendered score card.
    with gr.Row():
        with gr.Column(scale=1):
            img_input = gr.Image(type="pil", label="Input Image", height=340)
            run_btn = gr.Button("✨ Score image", variant="primary", size="lg")
        with gr.Column(scale=1):
            out_html = gr.HTML(
                value="<p style='color:#888;text-align:center;padding:40px 0'>"
                "Upload an image to see its score.</p>",
            )
    gr.Markdown(
        "---\n**Model:** [`purplesmartai/aesthetic-classifier`]"
        "(https://huggingface.co/purplesmartai/aesthetic-classifier)"
        " · **Backbone:** OpenAI CLIP ViT-L/14"
    )
    # Score on explicit button press and also whenever the image changes
    # (upload, clear, or replace).
    run_btn.click(fn=classify, inputs=img_input, outputs=out_html)
    img_input.change(fn=classify, inputs=img_input, outputs=out_html)

# Launch the app when run as a script (HF Spaces imports and launches it too).
if __name__ == "__main__":
    demo.launch()