YAML Metadata Warning: empty or missing YAML metadata in repo card

Check out the documentation for more information.

#!/usr/bin/env python3 """ Gradio Web UI for Color Name to RGB Predictor. Loads saved FastText + LSTM models and provides interactive color prediction. """

import os import warnings import re import numpy as np import torch import torch.nn as nn import gradio as gr from gensim.models import FastText from colorsys import rgb_to_hsv, rgb_to_hls

# Silence every warning category process-wide (e.g. library deprecation
# notices printed while the models load). This also hides warnings that may
# be useful when debugging — remove temporarily if something misbehaves.
warnings.simplefilter("ignore")

# ── Configuration ────────────────────────────────────────────────────────────

# Model artefact locations. The defaults are the files shipped next to this
# script; set FT_MODEL_PATH / LSTM_MODEL_PATH in the environment to override.
FT_MODEL_PATH = os.environ.get('FT_MODEL_PATH', 'best_fasttext_model.ft')
LSTM_MODEL_PATH = os.environ.get('LSTM_MODEL_PATH', 'best_color_model.pt')

# BiLSTM+Attention hyperparameters. These must match the saved checkpoint,
# and VEC_SIZE must equal the FastText vector size (100-dim).
VEC_SIZE = 100       # per-token embedding size fed to the LSTM
MAX_TOKENS = 4       # names are truncated / zero-padded to this many tokens
HIDDEN_SIZE = 256    # LSTM hidden units per direction
BIDIRECTIONAL = True
DROPOUT = 0.3
NUM_LAYERS = 2

# ── Device ───────────────────────────────────────────────────────────────────

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}")

# ── Data Cleaning (must match training) ──────────────────────────────────────

# Whole-word typo / synonym rewrites, applied one after another in this
# order before punctuation is stripped. Precompiled once at import time.
_TYPO_FIXES = {
    'violence': 'violet',
    'greylightblue': 'grey light blue',
    'greem': 'green',
    'grenn': 'green',
    'mocca': 'mocha',
    'radish': 'reddish',
    'greenerer': 'greener',
    'marroon': 'maroon',
    'vert': 'green',
    'techelet': 'teal',
    'majenta': 'magenta',
    'magink': 'magenta pink',
    'orangegray': 'orange gray',
    'yellowbrowngreen': 'yellow brown green',
    'fungal growth': 'fungal green',
}
_TYPO_PATTERNS = [
    (re.compile(rf'\b{re.escape(typo)}\b'), fix)
    for typo, fix in _TYPO_FIXES.items()
]
_NON_NAME_CHARS = re.compile(r'[^\w\s-]')
_WHITESPACE_RUN = re.compile(r'\s+')


def clean_color_name(name: str) -> str:
    """Normalise a color name exactly as the training data was cleaned.

    Steps, in order:
      1. lowercase and trim surrounding whitespace;
      2. rewrite known whole-word typos (see ``_TYPO_FIXES``);
      3. drop every character that is not a word character, whitespace or ``-``;
      4. collapse whitespace runs to one space and trim again.

    Because typos are fixed before punctuation is removed, the function is
    not idempotent (e.g. ``"gre.em"`` becomes ``"greem"``, not ``"green"``).
    """
    cleaned = name.lower().strip()
    for pattern, replacement in _TYPO_PATTERNS:
        cleaned = pattern.sub(replacement, cleaned)
    cleaned = _NON_NAME_CHARS.sub('', cleaned)
    return _WHITESPACE_RUN.sub(' ', cleaned).strip()

# ── Model Definition (MUST MATCH SAVED CHECKPOINT - BiLSTM + Attention) ───────

class ColorBiLSTM(nn.Module):
    """BiLSTM with attention pooling that maps a color-name embedding to RGB.

    The submodule names and sizes must match the checkpoint loaded from
    ``LSTM_MODEL_PATH``. ``load_state_dict`` matches parameters by attribute
    name (``lstm``, ``attention``, ``fc1``, ``bn1``, ``fc2``, ``bn2``,
    ``fc3``), so renaming or resizing any of them breaks loading.

    Args:
        input_size: Per-token feature size (the FastText vector size).
        hidden_size: LSTM hidden units per direction.
        num_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM reads the sequence in both directions.
        dropout: Dropout rate between stacked LSTM layers (only used when
            ``num_layers > 1``) and after each hidden fully connected layer.
        output_size: Number of regression outputs (3 for R, G, B).
    """

    def __init__(
        self,
        input_size=VEC_SIZE,
        hidden_size=HIDDEN_SIZE,
        num_layers=NUM_LAYERS,
        bidirectional=BIDIRECTIONAL,
        dropout=DROPOUT,
        output_size=3,
    ):
        # Must be the real ``__init__`` chaining to ``nn.Module.__init__``:
        # submodules can only be registered after the base initializer runs.
        super().__init__()
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            # nn.LSTM applies dropout only between stacked layers.
            dropout=dropout if num_layers > 1 else 0,
        )

        lstm_out_size = hidden_size * self.num_directions
        # Small MLP producing one score per time step; softmaxed in forward().
        self.attention = nn.Sequential(
            nn.Linear(lstm_out_size, 64),
            nn.Tanh(),
            nn.Linear(64, 1),
        )

        self.fc1 = nn.Linear(lstm_out_size, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(256, 128)
        self.bn2 = nn.BatchNorm1d(128)
        self.dropout2 = nn.Dropout(dropout)
        self.fc3 = nn.Linear(128, output_size)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, x):
        """Predict colors for a batch of token-embedding sequences.

        Args:
            x: Float tensor of shape (batch, seq_len, input_size).

        Returns:
            Tensor of shape (batch, output_size) with values in (0, 1),
            produced by the final sigmoid.
        """
        lstm_out, _ = self.lstm(x)  # (batch, seq_len, hidden_size * num_directions)
        # NOTE(review): there is no padding mask, so zero-padded token slots
        # also receive attention weight. This presumably matches training;
        # adding a mask would change the loaded checkpoint's predictions.
        attn_weights = torch.softmax(self.attention(lstm_out), dim=1)
        context = torch.sum(attn_weights * lstm_out, dim=1)  # weighted sum over time

        x = self.dropout1(self.relu(self.bn1(self.fc1(context))))
        x = self.dropout2(self.relu(self.bn2(self.fc2(x))))
        return self.sigmoid(self.fc3(x))

# ── Load Models ──────────────────────────────────────────────────────────────

print("Loading FastText model...")
fasttext = FastText.load(FT_MODEL_PATH)
vocab_size = len(fasttext.wv)
print(f"FastText vocab size: {vocab_size}")

print("Loading BiLSTM model...")
model = ColorBiLSTM().to(device)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
checkpoint = torch.load(LSTM_MODEL_PATH, map_location=device)
state_dict = checkpoint['model_state_dict']
model.load_state_dict(state_dict)
# Inference mode for the Dropout / BatchNorm layers defined in ColorBiLSTM.
model.eval()
trained_epoch = checkpoint.get('epoch', 'unknown')
print(f"Model loaded from epoch {trained_epoch}")

# ── Load Training Data for Recommendations ──────────────────────────────────

def get_training_embeddings_sample(n=100000, path='xkcd_scaled_data_Final.txt'):
    """Load training-split color names, their embeddings and RGB values.

    Reads the first ``n`` rows of ``path`` (the head of the file, not a random
    sample). Names are cleaned with ``clean_color_name`` and empty ones are
    dropped. The 75/25 ``train_test_split`` with ``random_state=42`` is then
    re-created so that only the training portion is returned; presumably
    this mirrors the split used by the training script (verify).

    Args:
        n: Maximum number of CSV rows to read.
        path: CSV file with ``name``, ``red``, ``green`` and ``blue`` columns.

    Returns:
        Tuple ``(X, names, rgbs)``:
          * ``X``: float32 array of shape (num_train, MAX_TOKENS, VEC_SIZE)
            holding each name's FastText token vectors, zero-padded.
          * ``names``: list of cleaned names aligned with ``X``.
          * ``rgbs``: array (num_train, 3) with the raw ``red``/``green``/
            ``blue`` column values. NOTE(review): units are whatever the CSV
            stores (presumably 0-255); confirm before formatting as hex.
    """
    import pandas as pd
    from sklearn.model_selection import train_test_split

    data = pd.read_csv(path, nrows=n)
    data = data.dropna(subset=['name']).reset_index(drop=True)
    data['name_clean'] = data['name'].apply(clean_color_name)
    data = data[data['name_clean'].str.len() > 0].reset_index(drop=True)

    indices = np.arange(len(data))
    train_idx, _ = train_test_split(indices, test_size=0.25, random_state=42)

    train_names = data.loc[train_idx, 'name_clean'].tolist()
    train_rgbs = data.loc[train_idx, ['red', 'green', 'blue']].values

    # Same layout as get_embedding(): token j fills row j, extra tokens are
    # dropped. Vectors are cached per distinct token so a token repeated
    # across many names is looked up in FastText only once.
    token_cache = {}
    X = np.zeros((len(train_names), MAX_TOKENS, VEC_SIZE), dtype=np.float32)
    for i, name in enumerate(train_names):
        for j, token in enumerate(name.split()[:MAX_TOKENS]):
            vec = token_cache.get(token)
            if vec is None:
                vec = token_cache[token] = fasttext.wv[token]
            X[i, j] = vec

    return X, train_names, train_rgbs

# Build the recommendation pool once at import time. cosine_similarity() and
# get_recommendations() read these module-level names.
print("Loading training data for recommendations...")
X_train, train_names, train_rgbs = get_training_embeddings_sample(n=100000)
print(f"Loaded {len(train_names)} training samples for recommendations")

# ── Color Space Conversions ──────────────────────────────────────────────────

def rgb_to_hex(r, g, b):
    """Format an RGB triple as a lowercase ``#rrggbb`` string.

    Each component is rounded to the nearest integer and clamped to 0-255.
    Float components (e.g. values read from a CSV) are therefore accepted
    instead of raising ``ValueError``. Out-of-range values can no longer
    produce a malformed hex string (extra digits or a minus sign).

    Args:
        r, g, b: Channel values, nominally 0-255 (int, float or NumPy scalar).

    Returns:
        A 7-character string such as ``"#ff0080"``.
    """
    def to_byte(value):
        return max(0, min(255, int(round(float(value)))))

    return f"#{to_byte(r):02x}{to_byte(g):02x}{to_byte(b):02x}"

def rgb_to_lab(r, g, b):
    """Convert an sRGB color with 0-255 channels to CIE L*a*b*.

    The channels are processed in four steps:
      1. gamma-decode to linear light;
      2. project to CIE XYZ with the sRGB matrix;
      3. normalise by the D65 reference white (0.95047, 1.0, 1.08883);
      4. apply the L*a*b* companding curve.

    Returns:
        Tuple ``(L, a, b)`` of floats.
    """
    def linearize(channel):
        # Inverse sRGB transfer curve: linear toe at the dark end, power law above.
        return channel / 12.92 if channel <= 0.04045 else ((channel + 0.055) / 1.055) ** 2.4

    def compand(t):
        # L*a*b* non-linearity: cube root, with a linear segment near zero.
        return t ** (1/3) if t > 0.008856 else 7.787 * t + 16/116

    lin_r, lin_g, lin_b = (linearize(c / 255.0) for c in (r, g, b))

    xyz_matrix = (
        (0.4124564, 0.3575761, 0.1804375),
        (0.2126729, 0.7151522, 0.0721750),
        (0.0193339, 0.1191920, 0.9503041),
    )
    reference_white = (0.95047, 1.00000, 1.08883)

    fx, fy, fz = (
        compand((lin_r * kr + lin_g * kg + lin_b * kb) / white)
        for (kr, kg, kb), white in zip(xyz_matrix, reference_white)
    )

    return 116 * fy - 16, 500 * (fx - fy), 200 * (fy - fz)

def rgb_to_hsv_values(r, g, b):
    """Return ``(H, S, V)`` for an RGB color given as 0-255 channels.

    Hue is in degrees [0, 360); saturation and value are percentages.
    """
    hue, saturation, value = rgb_to_hsv(*(channel / 255.0 for channel in (r, g, b)))
    return hue * 360, saturation * 100, value * 100

def rgb_to_hsl_values(r, g, b):
    """Return ``(H, S, L)`` for an RGB color given as 0-255 channels.

    ``colorsys.rgb_to_hls`` yields (hue, lightness, saturation); the result
    is reordered to the conventional H, S, L. Hue is in degrees and
    saturation/lightness are percentages.
    """
    scaled = [channel / 255.0 for channel in (r, g, b)]
    hue, lightness, saturation = rgb_to_hls(*scaled)
    return hue * 360, saturation * 100, lightness * 100

# ── Prediction Functions ─────────────────────────────────────────────────────

def get_embedding(name):
    """Embed a raw color name as a (MAX_TOKENS, VEC_SIZE) float32 matrix.

    The name is cleaned with ``clean_color_name`` and split on whitespace.
    Token ``j`` fills row ``j`` with its FastText vector; tokens beyond
    MAX_TOKENS are dropped and unused rows stay zero.
    """
    tokens = clean_color_name(name).split()[:MAX_TOKENS]
    embedding = np.zeros((MAX_TOKENS, VEC_SIZE), dtype=np.float32)
    for row, token in enumerate(tokens):
        embedding[row] = fasttext.wv[token]
    return embedding

def predict_color(name):
    """Predict an RGB color for a free-text color name.

    Args:
        name: Raw color name; it is cleaned and embedded by ``get_embedding``.

    Returns:
        Tuple ``(r, g, b)`` of ints in 0-255. The model's sigmoid outputs in
        [0, 1] are scaled by 255 and rounded to the nearest integer.
    """
    x_test = get_embedding(name)
    with torch.no_grad():
        # Add a batch dimension of 1: (MAX_TOKENS, VEC_SIZE) -> (1, MAX_TOKENS, VEC_SIZE).
        x_tensor = torch.from_numpy(x_test).unsqueeze(0).to(device)
        pred = model(x_tensor).cpu().numpy()[0]
    # Round instead of truncating (int() always rounds toward zero, biasing
    # each channel down); the clamp keeps the result a valid byte.
    r, g, b = (min(255, max(0, int(round(float(c) * 255)))) for c in pred)
    return r, g, b

def cosine_similarity(query_vec, reference_vecs=None):
    """Rank reference embeddings by cosine similarity to a query embedding.

    The query and every reference entry are flattened before comparison, so
    the query must have as many elements as one reference entry. For
    example, get_embedding's (MAX_TOKENS, VEC_SIZE) matrix is compared
    against ``X_train[i]``.

    Args:
        query_vec: Query embedding array (any shape).
        reference_vecs: Array whose first axis indexes the references.
            Defaults to the module-level ``X_train``, looked up at call time.

    Returns:
        Integer index array ordering the references from most to least
        similar; equal scores keep ascending index order. If the query has
        zero norm the similarity is undefined, and the identity order
        ``0 .. N-1`` is returned.
    """
    if reference_vecs is None:
        reference_vecs = X_train

    query = np.asarray(query_vec).reshape(-1)
    query_norm = np.linalg.norm(query)
    if query_norm == 0:
        return np.arange(len(reference_vecs))

    refs = np.asarray(reference_vecs)
    if refs.shape[0] == 0:
        return np.arange(0)
    refs = refs.reshape(refs.shape[0], -1)

    ref_norms = np.linalg.norm(refs, axis=1)
    # The epsilon keeps all-zero reference rows at similarity 0 instead of NaN.
    cosine = (refs @ query) / (query_norm * ref_norms + 1e-8)
    # Stable sort on the negated scores gives a deterministic descending order.
    return np.argsort(-cosine, kind='stable')

def get_recommendations(query_embedding, top_k=5):
    """Return the ``top_k`` training colors most similar to an embedding.

    Args:
        query_embedding: Embedding as produced by ``get_embedding``.
        top_k: Maximum number of results.

    Returns:
        List of ``(name, hex_code, rgb)`` tuples ordered from most to least
        similar, where ``rgb`` is the matching row of ``train_rgbs``.
    """
    ranked = cosine_similarity(query_embedding)[:top_k]
    names = [train_names[i] for i in ranked]
    rgbs = [train_rgbs[i] for i in ranked]
    return [(name, rgb_to_hex(*rgb), rgb) for name, rgb in zip(names, rgbs)]

# ── Gradio Interface Functions ───────────────────────────────────────────────

def _color_patch_html(hex_code):
    """Return the HTML for the large swatch showing the predicted color."""
    return f"""
<div style="
    width: 120px; 
    height: 120px; 
    background-color: {hex_code}; 
    border: 2px solid #ddd; 
    border-radius: 10px;
    display: inline-block;
"></div>
"""


def _recommendations_html(recommendations):
    """Return a vertical HTML list of swatches for ``(name, hex, rgb)`` tuples.

    Names are inserted without HTML escaping. When they come from
    get_recommendations they are cleaned training names, which
    ``clean_color_name`` restricts to word characters, whitespace and ``-``.
    """
    items = []
    for name, rec_hex, _rgb in recommendations:
        items.append(f"""
    <div style="display: flex; align-items: center; padding: 10px; background: #f9f9f9; border-radius: 8px; border: 1px solid #eee;">
        <div style="
            width: 50px; 
            height: 50px; 
            background-color: {rec_hex}; 
            border: 1px solid #ddd; 
            border-radius: 6px;
            margin-right: 16px;
            flex-shrink: 0;
        "></div>
        <div>
            <div style="font-weight: 600; font-size: 14px;">{name}</div>
            <div style="font-family: monospace; color: #666; font-size: 13px;">{rec_hex.upper()}</div>
        </div>
    </div>
    """)
    return (
        '<div style="display: flex; flex-direction: column; gap: 8px;">'
        + "".join(items)
        + '</div>'
    )


def process_color_name(color_name):
    """Run the full prediction pipeline for one Gradio request.

    Args:
        color_name: Raw text from the input box (may be ``None`` or blank).

    Returns:
        7-tuple in the order of the Gradio outputs: color-swatch HTML, RGB
        text, HEX text, recommendations HTML, CIE LAB text, HSV text, HSL
        text. Blank input yields ``(None, "", "", "", "", "", "")``.
    """
    if not color_name or not color_name.strip():
        return None, "", "", "", "", "", ""

    color_name = color_name.strip()

    r, g, b = predict_color(color_name)
    hex_code = rgb_to_hex(r, g, b)

    recommendations = get_recommendations(get_embedding(color_name), top_k=5)

    L, a, b_lab = rgb_to_lab(r, g, b)
    h_hsv, s_hsv, v_hsv = rgb_to_hsv_values(r, g, b)
    h_hsl, s_hsl, l_hsl = rgb_to_hsl_values(r, g, b)

    return (
        _color_patch_html(hex_code),
        f"RGB: ({r}, {g}, {b})",
        f"HEX: {hex_code.upper()}",
        _recommendations_html(recommendations),
        f"L* = {L:.2f}, a* = {a:.2f}, b* = {b_lab:.2f}",
        f"H = {h_hsv:.1f}°, S = {s_hsv:.1f}%, V = {v_hsv:.1f}%",
        f"H = {h_hsl:.1f}°, S = {s_hsl:.1f}%, L = {l_hsl:.1f}%",
    )

# ── Build Gradio Interface ───────────────────────────────────────────────────

# Page-level CSS passed to gr.Blocks: cap the app width and center it.
css = (
    " .gradio-container {"
    " max-width: 900px !important;"
    " margin: 0 auto;"
    " } "
)

with gr.Blocks(css=css, title="Color Name to RGB Predictor") as demo:
    gr.Markdown(
        "# 🎨 Color Name to RGB Predictor\n\n"
        'Enter a color name (e.g., "ocean blue", "blood red", "forest green") '
        "and get the predicted RGB color with similar color recommendations "
        "and color space conversions."
    )

    # Input column on the left, predicted swatch on the right.
    with gr.Row():
        with gr.Column(scale=1):
            color_input = gr.Textbox(
                label="Color Name",
                placeholder="e.g., ocean blue, blood red, forest green, sunset orange, lavender",
                value="ocean blue",
            )
            predict_btn = gr.Button("Predict Color", variant="primary", size="lg")

        with gr.Column(scale=1):
            color_patch = gr.HTML(label="Predicted Color")

    with gr.Row():
        rgb_output = gr.Textbox(label="RGB Value", interactive=False)
        hex_output = gr.Textbox(label="HEX Code", interactive=False)

    gr.Markdown("## 🎯 Recommended Colors")
    recommendations = gr.HTML(label="Similar Colors")

    gr.Markdown("## 📐 Color Space Values")
    with gr.Row():
        with gr.Column():
            lab_output = gr.Textbox(label="CIE LAB", interactive=False)
        with gr.Column():
            hsv_output = gr.Textbox(label="HSV", interactive=False)
        with gr.Column():
            hsl_output = gr.Textbox(label="HSL", interactive=False)

    gr.Examples(
        examples=[
            ["ocean blue"],
            ["blood red"],
            ["forest green"],
            ["sunset orange"],
            ["lavender"],
            ["dark reddish brown"],
            ["bright neon pink"],
            ["pale yellow"],
            ["deep purple"],
            ["muddy green"],
            ["electric blue"],
            ["warm gray"],
        ],
        inputs=color_input,
        label="Try these examples",
    )

    gr.Markdown("""
---
### 📚 Citation
If you use this work in your research, please cite:
```bibtex
@article{jyothi2023text2color,
  title={Text2Color Networks: Deep Learning Models for Color Generation from Compositional Color Descriptions},
  author={Jyothi, Kondalarao and Okade, Manish},
  journal={International Journal on Artificial Intelligence Tools},
  volume={32},
  number={06},
  pages={2350026},
  year={2023},
  publisher={World Scientific}
}
```
""")

    # Order must match the 7-tuple returned by process_color_name.
    _prediction_outputs = [
        color_patch, rgb_output, hex_output, recommendations,
        lab_output, hsv_output, hsl_output,
    ]
    # Button click and pressing Enter in the textbox run the same pipeline.
    predict_btn.click(fn=process_color_name, inputs=color_input, outputs=_prediction_outputs)
    color_input.submit(fn=process_color_name, inputs=color_input, outputs=_prediction_outputs)

if __name__ == "__main__":
    # 0.0.0.0 listens on all network interfaces, not only localhost.
    demo.launch(server_name="0.0.0.0", server_port=7861, share=False)

Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support