YAML Metadata Warning:empty or missing yaml metadata in repo card
Check out the documentation for more information.
- ── Configuration ─────────────────────────────────────────────────────────────
- ── Device ────────────────────────────────────────────────────────────────────
- ── Data Cleaning (must match training) ───────────────────────────────────────
- ── Model Definition (MUST MATCH SAVED CHECKPOINT - BiLSTM + Attention) ───────
- ── Load Models ───────────────────────────────────────────────────────────────
- ── Load Training Data for Recommendations ────────────────────────────────────
- ── Color Space Conversions ───────────────────────────────────────────────────
- ── Prediction Functions ──────────────────────────────────────────────────────
- ── Gradio Interface Functions ────────────────────────────────────────────────
- ── Build Gradio Interface ────────────────────────────────────────────────────
#!/usr/bin/env python3
"""
Gradio Web UI for Color Name to RGB Predictor.
Loads saved FastText + LSTM models and provides interactive color prediction.
"""
import html
import os
import re
import warnings
from colorsys import rgb_to_hls, rgb_to_hsv

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from gensim.models import FastText
# Silence every warning category for the whole process.
warnings.filterwarnings('ignore')

# ── Configuration ─────────────────────────────────────────────────────────────
# Paths of the saved artifacts, resolved relative to the working directory.
FT_MODEL_PATH = 'best_fasttext_model.ft'
LSTM_MODEL_PATH = 'best_color_model.pt'  # BiLSTM+Attention matches FastText 100-dim

# Each constant sits on its own line: a trailing comment on a shared line
# would swallow the assignments that follow it and leave them undefined.
VEC_SIZE = 100        # width of one token embedding
MAX_TOKENS = 4        # tokens kept per color name; shorter names are zero-padded
HIDDEN_SIZE = 256     # default LSTM hidden size for ColorBiLSTM
BIDIRECTIONAL = True  # default LSTM direction setting for ColorBiLSTM
DROPOUT = 0.3         # default dropout probability for ColorBiLSTM
NUM_LAYERS = 2        # default number of stacked LSTM layers for ColorBiLSTM

# ── Device ────────────────────────────────────────────────────────────────────
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
# ── Data Cleaning (must match training) ───────────────────────────────────────
def clean_color_name(name: str) -> str:
    """Normalise a raw color name into lowercase, space-separated words.

    Steps, in order: lowercase and trim, replace known misspellings (whole
    words only), drop every character that is not a word character,
    whitespace or hyphen, then collapse whitespace runs to single spaces.

    Args:
        name: Raw color name.

    Returns:
        The cleaned name; may be empty.
    """
    # (misspelling, replacement) pairs, applied one after another.
    corrections = (
        ('violence', 'violet'),
        ('greylightblue', 'grey light blue'),
        ('greem', 'green'),
        ('grenn', 'green'),
        ('mocca', 'mocha'),
        ('radish', 'reddish'),
        ('greenerer', 'greener'),
        ('marroon', 'maroon'),
        ('vert', 'green'),
        ('techelet', 'teal'),
        ('majenta', 'magenta'),
        ('magink', 'magenta pink'),
        ('orangegray', 'orange gray'),
        ('yellowbrowngreen', 'yellow brown green'),
        ('fungal growth', 'fungal green'),
    )

    cleaned = name.lower().strip()
    for wrong, right in corrections:
        # \b anchors keep e.g. "vert" from matching inside "vertigo".
        cleaned = re.sub(rf'\b{re.escape(wrong)}\b', right, cleaned)

    without_punctuation = re.sub(r'[^\w\s-]', '', cleaned)
    return re.sub(r'\s+', ' ', without_punctuation).strip()
# ── Model Definition (MUST MATCH SAVED CHECKPOINT - BiLSTM + Attention) ───────
class ColorBiLSTM(nn.Module):
    """BiLSTM + Attention architecture from test_best_color_model.pt

    An LSTM encodes the token sequence, an attention layer pools the
    per-step outputs into one context vector, and a three-layer MLP with
    batch norm and dropout maps that vector to ``output_size`` values
    squashed into (0, 1) by a sigmoid.

    Layer names and sizes must stay as they are: the saved ``state_dict``
    is keyed by these attribute names.
    """

    def __init__(
        self,
        input_size=VEC_SIZE,
        hidden_size=HIDDEN_SIZE,
        num_layers=NUM_LAYERS,
        bidirectional=BIDIRECTIONAL,
        dropout=DROPOUT,
        output_size=3
    ):
        """Build the layers.

        Args:
            input_size: Width of each token embedding.
            hidden_size: LSTM hidden size per direction.
            num_layers: Number of stacked LSTM layers.
            bidirectional: Whether the LSTM runs in both directions.
            dropout: Dropout probability for the LSTM (between layers) and
                the MLP head.
            output_size: Number of predicted values.
        """
        # Must be the real constructor: nn.Module.__init__ has to run before
        # any sub-module is assigned as an attribute.
        super().__init__()
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            # Inter-layer dropout only exists when there is more than one layer.
            dropout=dropout if num_layers > 1 else 0
        )
        lstm_out_size = hidden_size * self.num_directions
        # Scores each time step with a single scalar.
        self.attention = nn.Sequential(
            nn.Linear(lstm_out_size, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )
        self.fc1 = nn.Linear(lstm_out_size, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(256, 128)
        self.bn2 = nn.BatchNorm1d(128)
        self.dropout2 = nn.Dropout(dropout)
        self.fc3 = nn.Linear(128, output_size)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a batch of embedded names to sigmoid-activated outputs.

        Args:
            x: Tensor laid out batch-first (the LSTM is built with
                ``batch_first=True``), i.e. ``(batch, seq, input_size)``.

        Returns:
            Tensor of shape ``(batch, output_size)`` with values in (0, 1).
        """
        lstm_out, _ = self.lstm(x)
        # Softmax over dim=1 (the sequence axis) turns scores into weights.
        # No padding mask is applied, so all time steps take part.
        attn_weights = self.attention(lstm_out)
        attn_weights = torch.softmax(attn_weights, dim=1)
        # Weighted sum over time steps -> one context vector per sample.
        context = torch.sum(attn_weights * lstm_out, dim=1)
        x = self.relu(self.bn1(self.fc1(context)))
        x = self.dropout1(x)
        x = self.relu(self.bn2(self.fc2(x)))
        x = self.dropout2(x)
        x = self.sigmoid(self.fc3(x))
        return x
# ── Load Models ───────────────────────────────────────────────────────────────
# Word-embedding model; `fasttext.wv[token]` is used to embed each token.
print("Loading FastText model...")
fasttext = FastText.load(FT_MODEL_PATH)
print(f"FastText vocab size: {len(fasttext.wv)}")

# Color regressor: build with the default architecture, move it to the
# selected device, then restore the trained weights.
print("Loading BiLSTM model...")
model = ColorBiLSTM()
model.to(device)
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(LSTM_MODEL_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
# Inference mode: disables dropout and makes batch norm use running stats.
model.eval()
print(f"Model loaded from epoch {checkpoint.get('epoch', 'unknown')}")
# ── Load Training Data for Recommendations ────────────────────────────────────
def get_training_embeddings_sample(n=100000):
    """Load a sample of training data for recommendations.

    Reads at most ``n`` rows of ``xkcd_scaled_data_Final.txt``, drops rows
    with a missing or empty-after-cleaning name, keeps the 75% "train" side
    of a fixed-seed split, and embeds every kept name.

    Args:
        n: Maximum number of CSV rows to read.

    Returns:
        Tuple ``(X, train_names, train_rgbs)``: a float32 array of shape
        ``(len(train_names), MAX_TOKENS, VEC_SIZE)``, the list of cleaned
        names, and an array holding the ``red``/``green``/``blue`` columns
        for the same rows.

    NOTE(review): the color columns are returned exactly as stored in the
    CSV; callers that treat them as 0-255 channels should confirm the file
    uses that scale.
    """
    import pandas as pd
    from sklearn.model_selection import train_test_split

    frame = pd.read_csv('xkcd_scaled_data_Final.txt', nrows=n)
    frame = frame.dropna(subset=['name']).reset_index(drop=True)
    frame['name_clean'] = frame['name'].apply(clean_color_name)
    has_text = frame['name_clean'].str.len() > 0
    frame = frame[has_text].reset_index(drop=True)

    # Fixed seed keeps the selected rows identical between runs.
    row_ids = np.arange(len(frame))
    train_idx, _ = train_test_split(row_ids, test_size=0.25, random_state=42)

    train_names = frame.loc[train_idx, 'name_clean'].tolist()
    train_rgbs = frame.loc[train_idx, ['red', 'green', 'blue']].values

    # Build embeddings: one zero-padded (MAX_TOKENS, VEC_SIZE) matrix per name.
    X = np.zeros((len(train_names), MAX_TOKENS, VEC_SIZE), dtype=np.float32)
    for matrix, name in zip(X, train_names):
        for slot, token in enumerate(name.split()[:MAX_TOKENS]):
            matrix[slot] = fasttext.wv[token]

    return X, train_names, train_rgbs
# Reference set searched by the recommendation lookup; built once at import.
print("Loading training data for recommendations...")
X_train, train_names, train_rgbs = get_training_embeddings_sample(n=100000)
print(f"Loaded {len(train_names)} training samples for recommendations")
# ── Color Space Conversions ───────────────────────────────────────────────────
def rgb_to_hex(r, g, b):
    """Return the lowercase ``#rrggbb`` string for an RGB triple.

    Each channel is rounded to the nearest integer and clamped to 0-255, so
    float inputs (including NumPy scalars) and slightly out-of-range values
    still give a valid six-digit code instead of raising or producing
    malformed output.

    Args:
        r: Red channel, nominally 0-255.
        g: Green channel, nominally 0-255.
        b: Blue channel, nominally 0-255.

    Returns:
        A seven-character string such as ``"#ff0010"``.
    """
    def channel(value):
        # float() accepts Python and NumPy numbers alike; round() -> int.
        return min(255, max(0, round(float(value))))

    return f"#{channel(r):02x}{channel(g):02x}{channel(b):02x}"
def rgb_to_lab(r, g, b):
    """Convert RGB (0-255) to CIE LAB.

    Channels are scaled to 0-1, linearised, mapped to XYZ with a fixed 3x3
    matrix, normalised by the reference white ``(0.95047, 1.0, 1.08883)``
    and passed through the LAB transfer function.

    Returns:
        Tuple ``(L, a, b)`` of floats.
    """
    def linearise(channel):
        # Piecewise inverse gamma: linear segment below the threshold.
        return channel / 12.92 if channel <= 0.04045 else ((channel + 0.055) / 1.055) ** 2.4

    def lab_transfer(t):
        # Cube root above the threshold, linear approximation below it.
        return t ** (1/3) if t > 0.008856 else 7.787 * t + 16/116

    lin_r, lin_g, lin_b = (linearise(value / 255.0) for value in (r, g, b))

    x = lin_r * 0.4124564 + lin_g * 0.3575761 + lin_b * 0.1804375
    y = lin_r * 0.2126729 + lin_g * 0.7151522 + lin_b * 0.0721750
    z = lin_r * 0.0193339 + lin_g * 0.1191920 + lin_b * 0.9503041

    white_x, white_y, white_z = 0.95047, 1.00000, 1.08883
    fx = lab_transfer(x / white_x)
    fy = lab_transfer(y / white_y)
    fz = lab_transfer(z / white_z)

    return 116 * fy - 16, 500 * (fx - fy), 200 * (fy - fz)
def rgb_to_hsv_values(r, g, b):
    """Convert RGB (0-255) to HSV as (degrees, percent, percent)."""
    red, green, blue = r/255.0, g/255.0, b/255.0
    hue, saturation, value = rgb_to_hsv(red, green, blue)
    # colorsys works in 0-1; rescale to the usual display units.
    return hue * 360, saturation * 100, value * 100
def rgb_to_hsl_values(r, g, b):
    """Convert RGB (0-255) to HSL as (degrees, percent, percent)."""
    red, green, blue = r/255.0, g/255.0, b/255.0
    # colorsys returns H, L, S; the result is reordered to H, S, L.
    hue, lightness, saturation = rgb_to_hls(red, green, blue)
    return hue * 360, saturation * 100, lightness * 100
# ── Prediction Functions ──────────────────────────────────────────────────────
def get_embedding(name):
    """Embed a color name as a zero-padded ``(MAX_TOKENS, VEC_SIZE)`` array.

    The name is cleaned, split on whitespace and truncated to ``MAX_TOKENS``
    tokens; row ``j`` holds the FastText vector of token ``j`` and unused
    rows stay zero.
    """
    kept_tokens = clean_color_name(name).split()[:MAX_TOKENS]
    matrix = np.zeros((MAX_TOKENS, VEC_SIZE), dtype=np.float32)
    for slot, token in enumerate(kept_tokens):
        matrix[slot] = fasttext.wv[token]
    return matrix
def predict_color(name):
    """Predict RGB from color name.

    Args:
        name: Color name; it is cleaned and embedded by ``get_embedding``.

    Returns:
        Tuple ``(r, g, b)`` of Python ints in 0-255.
    """
    x_test = get_embedding(name)
    with torch.no_grad():
        # unsqueeze(0) adds the batch axis the model expects.
        x_tensor = torch.from_numpy(x_test).unsqueeze(0).to(device)
        pred = model(x_tensor).cpu().numpy()[0]

    # Round to the nearest channel value: plain int() truncates, which
    # biases every channel low and makes 255 almost unreachable.
    channels = np.clip(np.rint(pred * 255.0), 0, 255)
    r, g, b = (int(value) for value in channels)
    return r, g, b
def cosine_similarity(query_vec, reference_vecs=X_train):
    """Rank reference embeddings by cosine similarity to a query embedding.

    Despite the name, the return value is an index array, not scores.

    Args:
        query_vec: 2-D array; it is flattened before comparison.
        reference_vecs: 3-D array whose last two axes are flattened the
            same way. Defaults to the ``X_train`` array that exists when
            this function is defined.

    Returns:
        Indices into ``reference_vecs``, most similar first. For an
        all-zero query the indices are returned in their natural order.
    """
    rows, cols = query_vec.shape
    flat_query = query_vec.reshape(rows * cols)
    query_norm = np.linalg.norm(flat_query)
    if query_norm == 0:
        # Cosine is undefined for a zero vector; no meaningful ranking.
        return np.arange(len(reference_vecs))

    count, tokens, dims = reference_vecs.shape
    flat_refs = reference_vecs.reshape(count, tokens * dims)
    dot_products = np.dot(flat_query, flat_refs.T)
    ref_norms = np.linalg.norm(flat_refs, axis=1)
    # The epsilon keeps all-zero reference rows from dividing by zero.
    scores = dot_products / (query_norm * ref_norms + 1e-8)
    return scores.argsort()[::-1]
def get_recommendations(query_embedding, top_k=5):
    """Get top-k similar colors from training data.

    Args:
        query_embedding: Embedding passed on to ``cosine_similarity``.
        top_k: Number of results to return.

    Returns:
        List of ``(name, hex_code, rgb)`` tuples, best match first, where
        ``rgb`` is the matching row of ``train_rgbs``.

    NOTE(review): the rows are handed to ``rgb_to_hex`` unchanged, which
    presumes they are 0-255 channel values — confirm against the data file.
    """
    best_indices = cosine_similarity(query_embedding)[:top_k]
    return [
        (train_names[idx], rgb_to_hex(*train_rgbs[idx]), train_rgbs[idx])
        for idx in best_indices
    ]
# ── Gradio Interface Functions ────────────────────────────────────────────────
def process_color_name(color_name):
    """Produce every UI output for one color name.

    Args:
        color_name: Text typed by the user; may be empty or ``None``.

    Returns:
        Tuple ``(color_patch, rgb_str, hex_str, rec_html, lab_str, hsv_str,
        hsl_str)``. For blank input the tuple is ``None`` followed by six
        empty strings and no model is called.
    """
    if not color_name or not color_name.strip():
        return None, "", "", "", "", "", ""

    color_name = color_name.strip()

    # Predict
    r, g, b = predict_color(color_name)
    hex_code = rgb_to_hex(r, g, b)

    # Color patch HTML
    color_patch = f"""
    <div style="
        width: 120px;
        height: 120px;
        background-color: {hex_code};
        border: 2px solid #ddd;
        border-radius: 10px;
        display: inline-block;
    "></div>
    """

    # Recommendations
    query_emb = get_embedding(color_name)
    recommendations = get_recommendations(query_emb, top_k=5)

    # Names come from a data file and are placed into raw HTML, so they are
    # escaped even though the cleaning step already strips markup characters.
    rec_cards = [
        f"""
        <div style="display: flex; align-items: center; padding: 10px; background: #f9f9f9; border-radius: 8px; border: 1px solid #eee;">
            <div style="
                width: 50px;
                height: 50px;
                background-color: {rec_hex};
                border: 1px solid #ddd;
                border-radius: 6px;
                margin-right: 16px;
                flex-shrink: 0;
            "></div>
            <div>
                <div style="font-weight: 600; font-size: 14px;">{html.escape(str(name))}</div>
                <div style="font-family: monospace; color: #666; font-size: 13px;">{rec_hex.upper()}</div>
            </div>
        </div>
        """
        for name, rec_hex, rec_rgb in recommendations
    ]
    rec_html = (
        '<div style="display: flex; flex-direction: column; gap: 8px;">'
        + "".join(rec_cards)
        + '</div>'
    )

    # Color space conversions
    L, a, b_lab = rgb_to_lab(r, g, b)
    h_hsv, s_hsv, v_hsv = rgb_to_hsv_values(r, g, b)
    h_hsl, s_hsl, l_hsl = rgb_to_hsl_values(r, g, b)

    lab_str = f"L* = {L:.2f}, a* = {a:.2f}, b* = {b_lab:.2f}"
    # \u00b0 is the degree sign, written as an escape so that it cannot be
    # corrupted by a wrong file encoding.
    hsv_str = f"H = {h_hsv:.1f}\u00b0, S = {s_hsv:.1f}%, V = {v_hsv:.1f}%"
    hsl_str = f"H = {h_hsl:.1f}\u00b0, S = {s_hsl:.1f}%, L = {l_hsl:.1f}%"
    rgb_str = f"RGB: ({r}, {g}, {b})"
    hex_str = f"HEX: {hex_code.upper()}"

    return color_patch, rgb_str, hex_str, rec_html, lab_str, hsv_str, hsl_str
# ── Build Gradio Interface ────────────────────────────────────────────────────
# Page-level style: keep the app centred and at most 900px wide.
css = """
.gradio-container {
    max-width: 900px !important;
    margin: 0 auto;
}
"""

with gr.Blocks(css=css, title="Color Name to RGB Predictor") as demo:
    gr.Markdown("""
    # 🎨 Color Name to RGB Predictor
    Enter a color name (e.g., "ocean blue", "blood red", "forest green") and get the predicted RGB color
    with similar color recommendations and color space conversions.
    """)

    # Top row: input on the left, predicted color on the right.
    with gr.Row():
        with gr.Column(scale=1):
            color_input = gr.Textbox(
                label="Color Name",
                placeholder="e.g., ocean blue, blood red, forest green, sunset orange, lavender",
                value="ocean blue"
            )
            predict_btn = gr.Button("Predict Color", variant="primary", size="lg")

        with gr.Column(scale=1):
            color_patch = gr.HTML(label="Predicted Color")
            with gr.Row():
                rgb_output = gr.Textbox(label="RGB Value", interactive=False)
                hex_output = gr.Textbox(label="HEX Code", interactive=False)

    gr.Markdown("## 🎯 Recommended Colors")
    recommendations = gr.HTML(label="Similar Colors")

    # NOTE(review): the original emoji in this heading and in the Citation
    # heading could not be recovered; the ones used here are stand-ins.
    gr.Markdown("## 📊 Color Space Values")
    with gr.Row():
        with gr.Column():
            lab_output = gr.Textbox(label="CIE LAB", interactive=False)
        with gr.Column():
            hsv_output = gr.Textbox(label="HSV", interactive=False)
        with gr.Column():
            hsl_output = gr.Textbox(label="HSL", interactive=False)

    # Examples
    gr.Examples(
        examples=[
            ["ocean blue"],
            ["blood red"],
            ["forest green"],
            ["sunset orange"],
            ["lavender"],
            ["dark reddish brown"],
            ["bright neon pink"],
            ["pale yellow"],
            ["deep purple"],
            ["muddy green"],
            ["electric blue"],
            ["warm gray"],
        ],
        inputs=color_input,
        label="Try these examples"
    )

    # Citation
    gr.Markdown("""
    ---
    ### 📚 Citation
    If you use this work in your research, please cite:
    ```bibtex
    @article{jyothi2023text2color,
    title={Text2Color Networks: Deep Learning Models for Color Generation from Compositional Color Descriptions},
    author={Jyothi, Kondalarao and Okade, Manish},
    journal={International Journal on Artificial Intelligence Tools},
    volume={32},
    number={06},
    pages={2350026},
    year={2023},
    publisher={World Scientific}
    }
    ```
    """)

    # Both triggers fill the same components, in the order of the tuple
    # returned by process_color_name.
    result_components = [
        color_patch,
        rgb_output,
        hex_output,
        recommendations,
        lab_output,
        hsv_output,
        hsl_output,
    ]
    predict_btn.click(
        fn=process_color_name,
        inputs=color_input,
        outputs=result_components
    )
    color_input.submit(
        fn=process_color_name,
        inputs=color_input,
        outputs=result_components
    )
# Start the web server only when the file is run as a script. The dunder
# names are required: a bare `name` is undefined and raises NameError.
if __name__ == "__main__":
    # 0.0.0.0 listens on every network interface; no public share link.
    demo.launch(server_name="0.0.0.0", server_port=7861, share=False)