| |
| |
| |
| |
| |
| |
| |
|
|
| import streamlit as st |
| import json |
| import os |
| import numpy as np |
| from PIL import Image |
| from io import BytesIO |
| import requests |
| from collections import Counter |
| from skimage import color as skcolor |
| from sklearn.cluster import KMeans |
| from openai import OpenAI |
| import copy |
|
|
| |
| |
| |
# Path of the JSON file that persists like/dislike ratings between runs.
FEEDBACK_FILE = "feedback.json"


def load_feedback():
    """Load the persisted feedback store from ``FEEDBACK_FILE``.

    Returns:
        dict: The stored data, always containing a ``"ratings"`` list.
        An empty store (``{"ratings": []}``) is returned when the file does
        not exist, cannot be decoded as JSON, or does not hold a JSON object.
    """
    try:
        with open(FEEDBACK_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return {"ratings": []}
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        # A truncated/corrupt file must not take the whole app down.
        return {"ratings": []}

    if not isinstance(data, dict):
        return {"ratings": []}
    # Callers append to data["ratings"], so guarantee it is a list.
    if not isinstance(data.get("ratings"), list):
        data["ratings"] = []
    return data
|
|
def save_feedback(data):
    """Persist the feedback store to ``FEEDBACK_FILE`` as indented JSON.

    The data is first serialised to a sibling temporary file and then moved
    over the target with ``os.replace``. If serialisation fails, the
    previously saved file is left untouched instead of being truncated.

    Args:
        data: JSON-serialisable feedback store.

    Raises:
        TypeError, ValueError: If ``data`` cannot be serialised to JSON.
        OSError: If the file cannot be written or replaced.
    """
    tmp_path = f"{FEEDBACK_FILE}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
        os.replace(tmp_path, FEEDBACK_FILE)
    except Exception:
        # Do not leave a half-written temp file behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
|
|
| |
| |
| |
class DominantColorDetector:
    """Find the dominant colors of an image by clustering pixels in CIELAB.

    The image is resized, its pixels are (optionally) sub-sampled with extra
    sampling weight on the outer border, clustered with k-means, and each
    cluster is scored by pixel share, chroma and closeness to mid lightness.
    """

    # Clusters holding less than this share of the sampled pixels are dropped.
    MIN_PIXEL_FRACTION = 0.01
    # L* below this value is treated as near black.
    NEAR_BLACK_L = 15
    # L* above this value (with low chroma) is treated as near white.
    NEAR_WHITE_L = 90
    # Chroma below this value is treated as gray.
    NEAR_GRAY_CHROMA = 8
    # CIE76 distance below which two colors count as duplicates.
    DUPLICATE_DELTA_E = 12.0

    # Scoring weights used for any key the caller does not supply.
    DEFAULT_WEIGHTS = {"pixel_fraction": 0.35, "chroma": 0.5, "lightness": 0.15}
    # Share of the width/height that forms the boosted border band.
    EDGE_FRACTION = 0.10
    # Upper bound on the number of pixels handed to k-means.
    MAX_SAMPLED_PIXELS = 10000
    # Seed shared by pixel sampling and k-means so results are reproducible.
    RANDOM_SEED = 42

    def __init__(self, num_colors=12, resize_dim=200,
                 weights=None, edge_multiplier=2.5):
        """Configure the detector.

        Args:
            num_colors: Maximum number of k-means clusters.
            resize_dim: Width in pixels the image is resized to.
            weights: Optional mapping with any of the keys
                ``pixel_fraction``, ``chroma`` and ``lightness``; missing
                keys fall back to ``DEFAULT_WEIGHTS``.
            edge_multiplier: Relative sampling weight of border pixels.
        """
        self.num_colors = num_colors
        self.resize_dim = resize_dim
        # Merge over the defaults so a partial mapping cannot cause a
        # KeyError in _score_color; this also avoids aliasing the caller's dict.
        self.weights = {**self.DEFAULT_WEIGHTS, **(weights or {})}
        self.edge_multiplier = edge_multiplier

    def _preprocess(self, img: "Image.Image"):
        """Convert to RGB, resize to ``resize_dim`` wide, and flatten.

        Returns:
            tuple: ``(pixels, (width, height))`` where ``pixels`` is a
            float32 array reshaped to ``(-1, 3)``.
        """
        if img.mode != "RGB":
            img = img.convert("RGB")
        original_w, original_h = img.size
        aspect = original_h / original_w
        new_h = max(1, int(self.resize_dim * aspect))
        img = img.resize((self.resize_dim, new_h), Image.LANCZOS)
        resized_w, resized_h = img.size
        pixels = np.array(img).reshape(-1, 3).astype(np.float32)
        return pixels, (resized_w, resized_h)

    def _build_edge_mask(self, total_pixels, resized_dims):
        """Return a per-pixel weight array that boosts the image border.

        Args:
            total_pixels: Length of the flattened pixel array.
            resized_dims: ``(width, height)`` of the resized image.

        Returns:
            numpy.ndarray: float32 array of length ``total_pixels`` holding
            ``edge_multiplier`` for pixels inside the border band and ``1.0``
            elsewhere, in row-major (``y * width + x``) order.
        """
        w, h = resized_dims
        mask = np.ones(total_pixels, dtype=np.float32)
        band_h = int(h * self.EDGE_FRACTION)
        band_w = int(w * self.EDGE_FRACTION)

        # 2-D view onto the flat mask: writes below land in ``mask`` itself.
        grid = mask[: h * w].reshape(h, w)
        grid[:band_h, :] = self.edge_multiplier
        # ``h - band_h`` (not ``-band_h``) so a zero-width band selects nothing.
        grid[h - band_h:, :] = self.edge_multiplier
        grid[:, :band_w] = self.edge_multiplier
        grid[:, w - band_w:] = self.edge_multiplier
        return mask

    def _pixels_to_lab(self, pixels_rgb):
        """RGB to CIELAB."""
        normalized = (pixels_rgb / 255.0).reshape(1, -1, 3)
        lab = skcolor.rgb2lab(normalized).reshape(-1, 3)
        return lab.astype(np.float32)

    def _lab_to_rgb(self, lab_array):
        """LAB to integer RGB clipped to 0..255."""
        rgb = skcolor.lab2rgb(lab_array.reshape(1, -1, 3)).reshape(-1, 3)
        return np.clip(rgb * 255, 0, 255).astype(int)

    def _is_near_black(self, lab):
        """True when L* is below ``NEAR_BLACK_L``."""
        return lab[0] < self.NEAR_BLACK_L

    def _is_near_white(self, lab):
        """True when L* exceeds ``NEAR_WHITE_L`` and chroma is gray-level."""
        return lab[0] > self.NEAR_WHITE_L and lab_chroma(lab) < self.NEAR_GRAY_CHROMA

    def _is_near_gray(self, lab):
        """True when chroma is below ``NEAR_GRAY_CHROMA``."""
        return lab_chroma(lab) < self.NEAR_GRAY_CHROMA

    def _remove_near_duplicates(self, color_list):
        """Keep each color only if it is far enough from all earlier keepers.

        Order is preserved, so with a score-sorted input the higher-scoring
        color of a near-duplicate pair survives.
        """
        kept = []
        for c in color_list:
            if not any(delta_e_cie76(c["lab"], k["lab"]) < self.DUPLICATE_DELTA_E
                       for k in kept):
                kept.append(c)
        return kept

    def _score_color(self, lab, pixel_fraction):
        """Weighted perceptual importance score.

        Returns a plain Python ``float``: ``lab`` usually holds NumPy scalars,
        and a NumPy result would not be JSON-serialisable downstream.
        """
        chroma = lab_chroma(lab)
        chroma_score = min(chroma / 60.0, 1.0)
        lightness = float(lab[0])
        # 1.0 at L*=50, falling linearly to 0.0 at L*=0 and L*=100.
        lightness_score = max(1.0 - abs(lightness - 50.0) / 50.0, 0.0)
        return float(pixel_fraction * self.weights["pixel_fraction"] +
                     chroma_score * self.weights["chroma"] +
                     lightness_score * self.weights["lightness"])

    def _build_adaptive_palette(self, sorted_colors):
        """Pick a background, primary text and secondary text color.

        Args:
            sorted_colors: Color dicts ordered by descending score.

        Returns:
            dict | None: Palette with ``background``, ``primaryText`` and
            ``secondaryText`` entries, or ``None`` for an empty input.
        """
        if not sorted_colors:
            return None

        # Background: best-scoring color that is neither near black nor near
        # white; otherwise the first near-black one; otherwise the top color.
        bg = next(
            (c for c in sorted_colors
             if not c["isNearBlack"] and not c["isNearWhite"]),
            None,
        )
        if bg is None:
            near_blacks = [c for c in sorted_colors if c["isNearBlack"]]
            bg = near_blacks[0] if near_blacks else sorted_colors[0]
        bg_rgb = (bg["rgb"]["red"], bg["rgb"]["green"], bg["rgb"]["blue"])

        # Primary text: white if it reaches 4.5:1, else black if it does,
        # else whichever of the two contrasts more.
        white_cr = contrast_ratio(bg_rgb, (255, 255, 255))
        black_cr = contrast_ratio(bg_rgb, (0, 0, 0))
        if white_cr >= 4.5 or (black_cr < 4.5 and white_cr >= black_cr):
            primary_rgb, primary_cr = (255, 255, 255), white_cr
        else:
            primary_rgb, primary_cr = (0, 0, 0), black_cr

        # Secondary text: first other dominant color with at least 3:1
        # contrast. Skip the background itself rather than blindly skipping
        # index 0, which is only correct when the background is the top color.
        secondary_rgb = None
        for c in sorted_colors:
            if c is bg:
                continue
            candidate = (c["rgb"]["red"], c["rgb"]["green"], c["rgb"]["blue"])
            if contrast_ratio(bg_rgb, candidate) >= 3.0:
                secondary_rgb = candidate
                break
        if secondary_rgb is None:
            # Fall back to the primary text color blended 70/30 with the background.
            alpha = 0.70
            secondary_rgb = tuple(
                int(primary_rgb[i] * alpha + bg_rgb[i] * (1 - alpha))
                for i in range(3)
            )
        secondary_cr = contrast_ratio(bg_rgb, secondary_rgb)

        return {
            "background": {
                "color": rgb_to_hex(bg_rgb),
                "rgb": bg["rgb"],
            },
            "primaryText": {
                "color": rgb_to_hex(primary_rgb),
                "rgb": {"red": primary_rgb[0], "green": primary_rgb[1], "blue": primary_rgb[2]},
                "contrastRatio": round(primary_cr, 2),
                "meetsWCAG_AA": primary_cr >= 4.5,
            },
            "secondaryText": {
                "color": rgb_to_hex(secondary_rgb),
                "rgb": {"red": secondary_rgb[0], "green": secondary_rgb[1], "blue": secondary_rgb[2]},
                "contrastRatio": round(secondary_cr, 2),
                "meetsWCAG_AA": secondary_cr >= 4.5,
            },
        }

    def detect_properties(self, img: "Image.Image", include_palette=True):
        """Detect dominant colors and optionally a suggested palette.

        Args:
            img: Source image.
            include_palette: When true, add a ``suggestedPalette`` entry.

        Returns:
            dict: ``{"imagePropertiesAnnotation": {"dominantColors":
            {"colors": [...]}}}`` with colors ordered by descending score,
            plus ``"suggestedPalette"`` when requested.
        """
        pixels_rgb, (resized_w, resized_h) = self._preprocess(img)
        total_pixels = len(pixels_rgb)
        edge_mask = self._build_edge_mask(total_pixels, (resized_w, resized_h))

        max_pixels = self.MAX_SAMPLED_PIXELS
        if total_pixels > max_pixels:
            # float64 probabilities avoid float32 rounding in the sum-to-one check.
            probs = edge_mask.astype(np.float64)
            probs /= probs.sum()
            # Seeded like KMeans below so the same image yields the same colors.
            rng = np.random.default_rng(self.RANDOM_SEED)
            idx = rng.choice(total_pixels, size=max_pixels, replace=False, p=probs)
            sampled_pixels = pixels_rgb[idx]
            sampled_mask = edge_mask[idx]
        else:
            sampled_pixels = pixels_rgb
            sampled_mask = edge_mask

        pixels_lab = self._pixels_to_lab(sampled_pixels)
        # KMeans rejects n_clusters > n_samples, which tiny images would hit.
        n_clusters = max(1, min(int(self.num_colors), len(pixels_lab)))
        kmeans = KMeans(n_clusters=n_clusters, random_state=self.RANDOM_SEED,
                        n_init=3, max_iter=100, algorithm="elkan")
        kmeans.fit(pixels_lab)
        centroids_lab = kmeans.cluster_centers_
        labels = kmeans.labels_
        counts = np.bincount(labels, minlength=n_clusters)
        total = len(labels)
        centroids_rgb = self._lab_to_rgb(centroids_lab)

        color_list = []
        for i in range(n_clusters):
            pf = float(counts[i]) / total
            if pf < self.MIN_PIXEL_FRACTION:
                continue
            lab = centroids_lab[i]
            r, g, b = (int(v) for v in centroids_rgb[i])
            # Share of this cluster's sampled pixels that lie in the border band.
            edge_frac = float(np.mean(sampled_mask[labels == i] > 1.0))
            color_list.append({
                "lab": lab,
                "rgb": {"red": r, "green": g, "blue": b},
                "color": rgb_to_hex((r, g, b)),
                "pixelFraction": pf,
                "score": self._score_color(lab, pf),
                "chroma": round(lab_chroma(lab), 2),
                "isNearBlack": bool(self._is_near_black(lab)),
                "isNearWhite": bool(self._is_near_white(lab)),
                "isNearGray": bool(self._is_near_gray(lab)),
                "edgeInfluence": round(edge_frac, 3),
            })

        color_list.sort(key=lambda x: x["score"], reverse=True)
        color_list = self._remove_near_duplicates(color_list)

        # The public result omits the internal "lab" array and rounds floats.
        dominant_colors = [
            {k: v for k, v in c.items() if k != "lab"}
            for c in color_list
        ]
        for c in dominant_colors:
            c["score"] = round(c["score"], 4)
            c["pixelFraction"] = round(c["pixelFraction"], 4)

        result = {
            "imagePropertiesAnnotation": {
                "dominantColors": {"colors": dominant_colors}
            }
        }
        if include_palette:
            result["suggestedPalette"] = self._build_adaptive_palette(color_list)
        return result
|
|
| |
| |
| |
def lab_chroma(lab):
    """Return the chroma of a CIELAB color: the length of its (a*, b*) vector.

    Args:
        lab: Indexable ``(L, a, b)`` triple.

    Returns:
        float: ``sqrt(a**2 + b**2)``.
    """
    a_star = lab[1]
    b_star = lab[2]
    squared_radius = a_star ** 2 + b_star ** 2
    return float(np.sqrt(squared_radius))
|
|
def delta_e_cie76(lab1, lab2):
    """Return the CIE76 color difference between two CIELAB colors.

    This is the Euclidean distance between the two ``(L, a, b)`` triples.

    Args:
        lab1: First color as a sequence or array.
        lab2: Second color as a sequence or array.

    Returns:
        float: The distance.
    """
    first = np.array(lab1)
    second = np.array(lab2)
    squared_diff = (first - second) ** 2
    return float(np.sqrt(np.sum(squared_diff)))
|
|
def relative_luminance(rgb_tuple):
    """Return the relative luminance of an sRGB color with 0..255 channels.

    Each channel is scaled to 0..1, converted from gamma-encoded to linear
    light, and the channels are combined with the weights
    0.2126 (red), 0.7152 (green) and 0.0722 (blue).

    Args:
        rgb_tuple: ``(red, green, blue)`` triple.

    Returns:
        float: Luminance from 0.0 (black) to about 1.0 (white).
    """
    def to_linear(channel):
        scaled = channel / 255.0
        if scaled <= 0.04045:
            # Linear segment of the sRGB transfer curve.
            return scaled / 12.92
        return ((scaled + 0.055) / 1.055) ** 2.4

    red, green, blue = rgb_tuple
    red_part = 0.2126 * to_linear(red)
    green_part = 0.7152 * to_linear(green)
    blue_part = 0.0722 * to_linear(blue)
    return red_part + green_part + blue_part
|
|
def contrast_ratio(rgb1, rgb2):
    """Return the contrast ratio between two RGB colors.

    The ratio is ``(lighter + 0.05) / (darker + 0.05)`` computed on the two
    relative luminances, so it does not depend on argument order.

    Args:
        rgb1: First ``(red, green, blue)`` triple.
        rgb2: Second ``(red, green, blue)`` triple.

    Returns:
        float: Contrast ratio, at least 1.0.
    """
    lum_high = relative_luminance(rgb1)
    lum_low = relative_luminance(rgb2)
    if lum_high < lum_low:
        # Put the brighter color in the numerator.
        lum_high, lum_low = lum_low, lum_high
    return (lum_high + 0.05) / (lum_low + 0.05)
|
|
def rgb_to_hex(rgb_tuple):
    """Format an RGB triple as a lowercase ``#rrggbb`` string.

    Channels are truncated to integers and clamped to 0..255 so the result
    is always a valid six-digit hex color, even for out-of-range input.

    Args:
        rgb_tuple: Indexable ``(red, green, blue)`` with numeric channels.

    Returns:
        str: The color as ``#rrggbb``.
    """
    r, g, b = (min(max(int(rgb_tuple[i]), 0), 255) for i in range(3))
    return f"#{r:02x}{g:02x}{b:02x}"
|
|
| |
| |
| |
st.set_page_config(page_title="Apple Music‑like Palette Tuner", layout="wide")
st.title("🎨 Dominant Color Learner (Apple Music Style)")


def _clamped_float(value, low, high, fallback):
    """Coerce ``value`` to a float within ``[low, high]``.

    Returns ``fallback`` when the value is missing, non-numeric or NaN.
    """
    try:
        number = float(value)
    except (TypeError, ValueError):
        return fallback
    if number != number:  # NaN never compares equal to itself
        return fallback
    return min(max(number, low), high)


with st.sidebar:
    st.header("⚙️ Detector Settings")
    num_colors = st.slider("Max clusters", 6, 24, 12)
    resize_dim = st.number_input("Resize dimension", 100, 500, 200)

    st.header("🧠 AI Preferences")
    api_key = st.text_input("OpenAI API key (e.g., OpenRouter)",
                            type="password")
    api_base = st.text_input("API base URL",
                             value="https://openrouter.ai/api/v1")
    model_name = st.text_input("Model name", value="openai/gpt-4o")

    with st.expander("⚖️ Current Scoring Weights"):
        default_weights = {"pixel_fraction": 0.35,
                           "chroma": 0.5,
                           "lightness": 0.15}
        w = st.session_state.setdefault("weights", dict(default_weights))
        if not isinstance(w, dict):
            # Session weights can be overwritten by the AI step further down;
            # recover instead of failing on every rerun.
            w = dict(default_weights)
            st.session_state.weights = w

        weight_inputs = (
            ("pixel_fraction", "Pixel fraction weight"),
            ("chroma", "Chroma weight"),
            ("lightness", "Lightness weight"),
        )
        for field, label in weight_inputs:
            # number_input raises when the value is outside its bounds or its
            # numeric type differs from them, so normalise the stored value.
            current = _clamped_float(w.get(field), 0.0, 1.0, default_weights[field])
            w[field] = st.number_input(label, 0.0, 1.0, current, 0.05)

        edge_multiplier = st.number_input(
            "Edge bonus multiplier", 1.0, 5.0,
            _clamped_float(st.session_state.get("edge_multiplier", 2.5), 1.0, 5.0, 2.5),
            0.1,
        )
        st.session_state.edge_multiplier = edge_multiplier

    if st.button("📤 Export improved weights"):
        export_data = {
            "weights": st.session_state.weights,
            "edge_multiplier": st.session_state.edge_multiplier
        }
        st.download_button(
            label="Download weights.json",
            data=json.dumps(export_data, indent=2),
            file_name="dominant_color_weights.json",
            mime="application/json"
        )
|
|
| |
st.subheader("📌 Enter Spotify / Apple Music cover art URLs")
urls_text = st.text_area("One URL per line", height=150)
urls = [u.strip() for u in urls_text.splitlines() if u.strip()]

if st.button("🔍 Analyze covers", type="primary"):
    st.session_state.analysis_results = []
    if not urls:
        st.warning("Please enter at least one cover art URL.")
    else:
        progress = st.progress(0)
        # Settings do not change during the loop, so one detector is enough.
        detector = DominantColorDetector(
            num_colors=num_colors,
            resize_dim=resize_dim,
            weights=st.session_state.weights,
            edge_multiplier=st.session_state.edge_multiplier
        )
        for position, url in enumerate(urls, start=1):
            try:
                resp = requests.get(url, timeout=10)
                resp.raise_for_status()
                img = Image.open(BytesIO(resp.content))
                # Image.open is lazy; force decoding here so a truncated or
                # corrupt download is reported as a load failure.
                img.load()
            except Exception as e:
                st.error(f"Failed to load {url}: {e}")
                img = None

            if img is not None:
                try:
                    result = detector.detect_properties(img, include_palette=True)
                except Exception as e:
                    # One bad image must not abort the whole batch.
                    st.error(f"Failed to analyze {url}: {e}")
                else:
                    st.session_state.analysis_results.append({
                        "url": url,
                        "dominant": result["imagePropertiesAnnotation"]["dominantColors"]["colors"],
                        "palette": result.get("suggestedPalette"),
                        "img": img,
                        "features": []
                    })

            # Advance for failures too, so the bar always reaches 100%.
            progress.progress(position / len(urls))
        st.success(f"Analyzed {len(st.session_state.analysis_results)} covers.")
|
|
| |
def _rerun_app():
    """Restart the script run with whichever rerun API this Streamlit offers."""
    rerun = getattr(st, "rerun", None) or getattr(st, "experimental_rerun")
    rerun()


def _bounded_float(value, low, high):
    """Return ``value`` as a float clamped to ``[low, high]``.

    Raises:
        TypeError, ValueError: If the value is not numeric or is NaN.
    """
    number = float(value)
    if number != number:  # NaN never compares equal to itself
        raise ValueError("value is not a number")
    return min(max(number, low), high)


def _parse_ai_params(content, current_weights, current_edge):
    """Extract and validate the tuning parameters from the model's reply.

    The JSON object is taken from the first ``{`` to the last ``}``, which
    tolerates Markdown code fences and surrounding prose. Weights are clamped
    to 0..1 and the edge multiplier to 1..5, the same bounds as the sidebar
    inputs, so stored values can always be shown there again.

    Returns:
        tuple: ``(weights, edge_multiplier, explanation)``.

    Raises:
        ValueError, TypeError: If no usable JSON object is found.
    """
    start, end = content.find("{"), content.rfind("}")
    if start == -1 or end < start:
        raise ValueError("response did not contain a JSON object")
    params = json.loads(content[start:end + 1])
    raw_weights = params.get("weights") if isinstance(params, dict) else None
    if not isinstance(raw_weights, dict):
        raise ValueError("response is missing a 'weights' object")
    weights = {
        name: _bounded_float(
            raw_weights.get(name, current_weights.get(name, 0.0)), 0.0, 1.0)
        for name in ("pixel_fraction", "chroma", "lightness")
    }
    edge = _bounded_float(params.get("edge_multiplier", current_edge), 1.0, 5.0)
    return weights, edge, str(params.get("explanation", ""))


if "analysis_results" in st.session_state and st.session_state.analysis_results:
    feedback_data = load_feedback()
    for idx, res in enumerate(st.session_state.analysis_results):
        url = res["url"]
        img = res["img"]
        palette = res["palette"]
        dominant = res["dominant"]
        with st.container():
            col1, col2, col3 = st.columns([2, 3, 3])
            with col1:
                st.image(img, width=200, caption=url[:50])
            with col2:
                st.markdown("**Dominant colors**")
                if dominant:
                    for c in dominant[:5]:
                        hex_color = c["color"]
                        score = c["score"]
                        st.markdown(
                            f'<span style="display:inline-block;width:20px;height:20px;'
                            f'background:{hex_color};border-radius:4px;margin-right:8px;"></span>'
                            f'{hex_color} – score {score:.3f}',
                            unsafe_allow_html=True
                        )
                else:
                    st.write("No dominant colors found.")
            with col3:
                if palette:
                    bg_hex = palette["background"]["color"]
                    primary_hex = palette["primaryText"]["color"]
                    st.markdown("**Suggested Palette**")
                    st.markdown(
                        f'<div style="background:{bg_hex};padding:12px;border-radius:8px;">'
                        f'<span style="color:{primary_hex};font-weight:bold;">Bg: {bg_hex}</span><br>'
                        f'<span style="color:{primary_hex};">Text: {primary_hex}</span>'
                        f'</div>',
                        unsafe_allow_html=True
                    )

            # Per-cover rating buttons; keys must be unique per result.
            key_prefix = f"fb_{idx}"
            colA, colB = st.columns(2)
            with colA:
                liked = st.button("👍 Like", key=key_prefix+"_like")
            with colB:
                disliked = st.button("👎 Dislike", key=key_prefix+"_dislike")

            if liked or disliked:
                rating = "like" if liked else "dislike"
                top_color = dominant[0] if dominant else None
                if top_color:
                    # float()/bool() strip NumPy scalar types, which json.dump
                    # cannot serialise.
                    features = {
                        "chroma": float(top_color["chroma"]),
                        "lightness": bool(top_color["isNearBlack"] or top_color["isNearWhite"]),
                        "pixelFraction": float(top_color["pixelFraction"]),
                        "edgeInfluence": float(top_color.get("edgeInfluence", 0.0)),
                        "isNearBlack": bool(top_color["isNearBlack"]),
                        "isNearWhite": bool(top_color["isNearWhite"]),
                        "score": float(top_color["score"]),
                    }
                else:
                    features = {}
                feedback_entry = {
                    "url": url,
                    "rating": rating,
                    "top_color_features": features,
                }
                feedback_data["ratings"].append(feedback_entry)
                save_feedback(feedback_data)
                st.toast(f"Recorded {rating} for the palette.", icon="✅")
                _rerun_app()

    st.markdown("---")
    col_stats1, col_stats2, _ = st.columns(3)
    ratings = feedback_data.get("ratings", [])
    likes = sum(1 for r in ratings if r.get("rating") == "like")
    dislikes = sum(1 for r in ratings if r.get("rating") == "dislike")
    col_stats1.metric("👍 Likes", likes)
    col_stats2.metric("👎 Dislikes", dislikes)
    if st.button("🧠 Learn from my feedback (AI)"):
        if not api_key:
            st.error("Please enter your OpenAI API key in the sidebar.")
        elif len(ratings) < 5:
            st.warning("Need at least 5 ratings to learn reliably.")
        else:
            with st.spinner("Asking AI to analyse your taste..."):
                current_weights = st.session_state.weights
                current_edge = st.session_state.edge_multiplier

                # Tell the model the values actually in use, not fixed defaults.
                prompt = (
                    "You are an expert in perceptual color psychology and UI design. "
                    "I have a dominant color detector for album cover art that scores colors "
                    "based on three features and an edge bonus. The scoring formula is:\n"
                    "score = pixel_fraction * w_pixel + chroma_score * w_chroma + lightness_score * w_lightness\n"
                    f"plus an edge bonus multiplier (currently {current_edge}) that duplicates edge pixels during sampling.\n"
                    f"The current weights are: {json.dumps(current_weights)}\n"
                    "Each weight must stay between 0 and 1 and the edge multiplier between 1 and 5.\n"
                    "I want to adjust these weights and the edge multiplier to match my personal taste.\n\n"
                    "Below is my feedback history on extracted palettes. Each entry includes the rating "
                    "(like/dislike) and the features of the top dominant color:\n"
                )
                table_lines = [
                    "| URL | Rating | Chroma | Lightness (0=mid, 1=extreme) | PixelFraction | EdgeInfluence | isNearBlack | isNearWhite | Score |",
                    "|-----|--------|--------|-------------------------------|---------------|---------------|-------------|-------------|-------|",
                ]
                # Only the 30 most recent ratings are sent to bound prompt size.
                for r in ratings[-30:]:
                    f = r.get("top_color_features", {})
                    table_lines.append(
                        f"| {str(r.get('url', ''))[:30]} | {r.get('rating', '')} | {f.get('chroma',0):.2f} | {f.get('lightness',0)} | {f.get('pixelFraction',0):.3f} | {f.get('edgeInfluence',0):.3f} | {f.get('isNearBlack',False)} | {f.get('isNearWhite',False)} | {f.get('score',0):.3f} |"
                    )
                prompt += "\n".join(table_lines) + "\n\n\n"
                prompt += (
                    "Based on this feedback, please output **only** a JSON object with the updated weights "
                    "and edge multiplier that would better satisfy my likes and avoid my dislikes.\n"
                    "Format exactly:\n"
                    "{\n"
                    '  "weights": {"pixel_fraction": <float>, "chroma": <float>, "lightness": <float>},\n'
                    '  "edge_multiplier": <float>,\n'
                    '  "explanation": "<brief reasoning>"\n'
                    "}\n"
                )

                try:
                    client = OpenAI(api_key=api_key, base_url=api_base)
                    response = client.chat.completions.create(
                        model=model_name,
                        messages=[{"role": "user", "content": prompt}],
                        temperature=0.2,
                        max_tokens=300,
                    )
                    # The message content may be None; treat that as empty.
                    content = (response.choices[0].message.content or "").strip()
                    new_weights, new_edge, explanation = _parse_ai_params(
                        content, current_weights, current_edge)
                    st.session_state.weights = new_weights
                    st.session_state.edge_multiplier = new_edge
                    st.success("Weights updated based on your feedback!")
                    st.write(f"**AI reasoning:** {explanation}")
                    st.json(new_weights)
                except Exception as e:
                    st.error(f"AI learning failed: {str(e)}")