# palette-edit-classifier / models/oklab_utils.py
# Author: Jonttup — uploaded with huggingface_hub (commit 27fdf85, verified).
"""
OKLab Color Space Utilities
Perceptually uniform color space for semantic loss computation.
OKLab ensures that equal distances in the color space correspond to
equal perceived differences — critical for meaningful color-based encoding.
Key functions:
- srgb_to_oklab / oklab_to_srgb: Color space conversions
- rotate_ab: Rotate hue in a-b plane (for domain/idiom shifts)
- set_chroma: Set chroma magnitude (for purity encoding)
- OKLabMSELoss: Perceptually uniform loss function
- hsl_to_oklab_batch: Batch conversion for training
"""
import torch
import torch.nn as nn
import math
from typing import Tuple
def clamp(x: float, lo: float, hi: float) -> float:
    """Return ``x`` limited to the closed interval [lo, hi].

    The upper bound is applied first and the lower bound last, so ``lo`` wins
    if the bounds are inverted.
    """
    capped = min(hi, x)
    return max(lo, capped)
# ── sRGB ↔ Linear RGB ──
def srgb_to_linear(c: float) -> float:
    """Decode one sRGB-encoded channel value to linear light.

    Uses the piecewise sRGB transfer function: a linear segment at or below
    0.04045, and a 2.4-power curve above it.
    """
    return c / 12.92 if c <= 0.04045 else ((c + 0.055) / 1.055) ** 2.4
def linear_to_srgb(c: float) -> float:
    """Encode one linear-light channel value with the sRGB transfer function.

    Values at or below 0.0031308 are scaled linearly; larger values follow
    the 1/2.4 power curve. This is the inverse of ``srgb_to_linear``.
    """
    if c > 0.0031308:
        return 1.055 * (c ** (1.0 / 2.4)) - 0.055
    return c * 12.92
# ── sRGB ↔ OKLab ──
def srgb_to_oklab(r: float, g: float, b: float) -> Tuple[float, float, float]:
    """Convert an sRGB triple with components in [0, 1] to OKLab ``(L, a, b)``.

    Pipeline: sRGB decode -> linear RGB -> LMS-style mix -> sign-preserving
    cube root -> OKLab mix.
    """
    def cbrt(v: float) -> float:
        # A negative float raised to 1/3 yields a complex number in Python,
        # so take the root of the magnitude and restore the sign.
        return -((-v) ** (1.0 / 3.0)) if v < 0 else v ** (1.0 / 3.0)

    rl, gl, bl = (srgb_to_linear(channel) for channel in (r, g, b))

    lc = cbrt(0.4122214708 * rl + 0.5363325363 * gl + 0.0514459929 * bl)
    mc = cbrt(0.2119034982 * rl + 0.6806995451 * gl + 0.1073969566 * bl)
    sc = cbrt(0.0883024619 * rl + 0.2817188376 * gl + 0.6299787005 * bl)

    return (
        0.2104542553 * lc + 0.7936177850 * mc - 0.0040720468 * sc,
        1.9779984951 * lc - 2.4285922050 * mc + 0.4505937099 * sc,
        0.0259040371 * lc + 0.7827717662 * mc - 0.8086757660 * sc,
    )
def oklab_to_srgb(L: float, a: float, b_ok: float) -> Tuple[float, float, float]:
    """Convert an OKLab colour ``(L, a, b_ok)`` back to sRGB in [0, 1].

    Linear RGB components are clamped to [0, 1] before sRGB encoding, and the
    encoded result is clamped again. Out-of-gamut colours are therefore
    clipped per channel rather than rejected.
    """
    lc = L + 0.3963377774 * a + 0.2158037573 * b_ok
    mc = L - 0.1055613458 * a - 0.0638541728 * b_ok
    sc = L - 0.0894841775 * a - 1.2914855480 * b_ok

    # Undo the cube root; explicit products keep the exact original arithmetic.
    l3 = lc * lc * lc
    m3 = mc * mc * mc
    s3 = sc * sc * sc

    linear_rgb = (
        +4.0767416621 * l3 - 3.3077115913 * m3 + 0.2309699292 * s3,
        -1.2684380046 * l3 + 2.6097574011 * m3 - 0.3413193965 * s3,
        -0.0041960863 * l3 - 0.7034186147 * m3 + 1.7076147010 * s3,
    )

    def encode(v: float) -> float:
        return clamp(linear_to_srgb(clamp(v, 0, 1)), 0, 1)

    return tuple(encode(v) for v in linear_rgb)
# ── HSL ↔ RGB ──
def hsl_to_rgb(h_deg: float, s_pct: float, l_pct: float) -> Tuple[float, float, float]:
    """Convert HSL (hue in degrees, saturation/lightness in percent) to RGB in [0, 1].

    Hue is wrapped modulo 360 degrees, so e.g. 720 maps to the same colour as 0
    and -120 to the same colour as 240. This matches the ``% 1.0`` wrap used by
    the vectorised ``hsl_to_oklab_batch``.

    Args:
        h_deg: Hue angle in degrees (any real value).
        s_pct: Saturation in percent.
        l_pct: Lightness in percent.

    Returns:
        ``(r, g, b)`` floats.
    """
    h = h_deg / 360.0
    s = s_pct / 100.0
    l = l_pct / 100.0
    if s == 0:
        # Achromatic: every channel equals the lightness.
        return (l, l, l)

    def hue_to_rgb(p: float, q: float, t: float) -> float:
        # Wrap fully onto [0, 1). A single +/-1 correction only handles hues
        # within one extra turn and mis-colours anything further out.
        t %= 1.0
        if t < 1 / 6:
            return p + (q - p) * 6 * t
        if t < 1 / 2:
            return q
        if t < 2 / 3:
            return p + (q - p) * (2 / 3 - t) * 6
        return p

    q = l * (1 + s) if l < 0.5 else l + s - l * s
    p = 2 * l - q
    r = hue_to_rgb(p, q, h + 1 / 3)
    g = hue_to_rgb(p, q, h)
    b = hue_to_rgb(p, q, h - 1 / 3)
    return (r, g, b)
def rgb_to_hsl(r: float, g: float, b: float) -> Tuple[float, float, float]:
    """Convert RGB components in [0, 1] to HSL as (degrees, percent, percent).

    Grey inputs (all channels equal) report hue 0 and saturation 0.
    """
    hi = max(r, g, b)
    lo = min(r, g, b)
    lightness = (hi + lo) / 2.0

    if hi == lo:
        return (0.0, 0.0, lightness * 100.0)

    span = hi - lo
    if lightness > 0.5:
        saturation = span / (2.0 - hi - lo)
    else:
        saturation = span / (hi + lo)

    # Hue sector (0..6) depends on which channel is dominant; red is checked
    # first, then green, so ties resolve in that order.
    if hi == r:
        sector = (g - b) / span + (6 if g < b else 0)
    elif hi == g:
        sector = (b - r) / span + 2
    else:
        sector = (r - g) / span + 4
    hue = sector / 6.0

    return (hue * 360.0, saturation * 100.0, lightness * 100.0)
# ── OKLab Operations ──
def rotate_ab(a: float, b: float, degrees: float) -> Tuple[float, float]:
    """Rotate the point ``(a, b)`` counter-clockwise by ``degrees`` in the a-b plane.

    In OKLab this shifts hue while leaving chroma (the vector length) unchanged.
    """
    theta = math.radians(degrees)
    c, s = math.cos(theta), math.sin(theta)
    a_rot = a * c - b * s
    b_rot = a * s + b * c
    return (a_rot, b_rot)
def set_chroma(a: float, b: float, target_c: float) -> Tuple[float, float]:
    """Rescale ``(a, b)`` so its length in the a-b plane equals ``target_c``.

    The hue direction is preserved. If the input is (near) zero-length
    (below 1e-10), there is no direction to keep, so the result points
    along +a.
    """
    norm = math.sqrt(a * a + b * b)
    if norm < 1e-10:
        return (target_c, 0.0)
    factor = target_c / norm
    return tuple(component * factor for component in (a, b))
def get_chroma(a: float, b: float) -> float:
    """Return the chroma, i.e. the Euclidean length of ``(a, b)``."""
    squared_length = a * a + b * b
    return math.sqrt(squared_length)
def compute_delta_e_oklab(
    L1: float, a1: float, b1: float,
    L2: float, a2: float, b2: float,
) -> float:
    """Return ΔE in OKLab: the Euclidean distance between two OKLab colours."""
    dL = L1 - L2
    da = a1 - a2
    db = b1 - b2
    return math.sqrt(dL ** 2 + da ** 2 + db ** 2)
# ── Batch Operations (PyTorch) ──
def _hue_channel(p: torch.Tensor, q: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Evaluate one RGB channel of the standard HSL hue ramp, element-wise."""
    t = t % 1.0  # wrap onto [0, 1) so shifted or out-of-range hues stay on the circle
    return torch.where(
        t < 1 / 6, p + (q - p) * 6 * t,
        torch.where(
            t < 1 / 2, q,
            torch.where(t < 2 / 3, p + (q - p) * (2 / 3 - t) * 6, p),
        ),
    )


def _srgb_to_linear_torch(c: torch.Tensor) -> torch.Tensor:
    """Element-wise sRGB decode whose gradient is finite for any real input."""
    # torch.where evaluates both branches. For c < -0.055 the unclamped power
    # curve is NaN, and that NaN leaks into the backward pass even though the
    # branch is not selected. Clamping is a no-op wherever the curve is used.
    curve = ((c.clamp_min(0.04045) + 0.055) / 1.055) ** 2.4
    return torch.where(c <= 0.04045, c / 12.92, curve)


def _signed_cbrt(x: torch.Tensor) -> torch.Tensor:
    """Sign-preserving cube root with a zero (not NaN) gradient at exactly 0."""
    mag = x.abs()
    nonzero = mag > 0
    # d/dx |x|^(1/3) is infinite at 0, and 0 * inf = NaN in backward.
    # Feed 1.0 to pow at zeros, then mask the result back to 0 (double-where trick).
    safe = torch.where(nonzero, mag, torch.ones_like(mag))
    root = torch.where(nonzero, safe.pow(1 / 3), torch.zeros_like(mag))
    return torch.sign(x) * root


def hsl_to_oklab_batch(hsl: torch.Tensor) -> torch.Tensor:
    """
    Batch convert HSL [0,1] normalized to OKLab.

    Works directly on normalized channels: hue in turns, saturation and
    lightness as fractions. Hue wraps modulo 1. The conversion is safe to
    differentiate: black colours and out-of-range intermediate RGB values do
    not produce NaN gradients.

    Args:
        hsl: (..., 3) tensor with H,S,L in [0,1]
    Returns:
        (..., 3) tensor with L,a,b in OKLab
    """
    h = hsl[..., 0]
    s = hsl[..., 1]
    l = hsl[..., 2]

    # HSL -> sRGB (vectorised form of the standard piecewise formula).
    q = torch.where(l < 0.5, l * (1 + s), l + s - l * s)
    p = 2 * l - q
    r = _hue_channel(p, q, h + 1 / 3)
    g = _hue_channel(p, q, h)
    b = _hue_channel(p, q, h - 1 / 3)

    # Achromatic (saturation below 0.001 percent): every channel equals lightness.
    achromatic = s < 1e-5
    r = torch.where(achromatic, l, r)
    g = torch.where(achromatic, l, g)
    b = torch.where(achromatic, l, b)

    # sRGB -> linear RGB
    r_lin = _srgb_to_linear_torch(r)
    g_lin = _srgb_to_linear_torch(g)
    b_lin = _srgb_to_linear_torch(b)

    # Linear RGB -> LMS-style cone responses -> OKLab
    l_ = 0.4122214708 * r_lin + 0.5363325363 * g_lin + 0.0514459929 * b_lin
    m_ = 0.2119034982 * r_lin + 0.6806995451 * g_lin + 0.1073969566 * b_lin
    s_ = 0.0883024619 * r_lin + 0.2817188376 * g_lin + 0.6299787005 * b_lin
    l_c = _signed_cbrt(l_)
    m_c = _signed_cbrt(m_)
    s_c = _signed_cbrt(s_)
    L_ok = 0.2104542553 * l_c + 0.7936177850 * m_c - 0.0040720468 * s_c
    a_ok = 1.9779984951 * l_c - 2.4285922050 * m_c + 0.4505937099 * s_c
    b_ok = 0.0259040371 * l_c + 0.7827717662 * m_c - 0.8086757660 * s_c
    return torch.stack([L_ok, a_ok, b_ok], dim=-1)
def denormalize_hsl(hsl_norm: torch.Tensor) -> torch.Tensor:
    """Return a copy of normalized HSL ([0,1] channels) in degrees/percent units.

    H: [0,1] → [0,360], S: [0,1] → [0,100], L: [0,1] → [0,100].
    The input tensor is not modified.
    """
    channel_scales = (360.0, 100.0, 100.0)
    out = hsl_norm.clone()
    for channel, factor in enumerate(channel_scales):
        out[..., channel] *= factor
    return out
class OKLabMSELoss(nn.Module):
    """
    Mean-squared error between two HSL colour batches, measured in OKLab.

    Both inputs go through ``hsl_to_oklab_batch`` and are compared with
    ``torch.nn.functional.mse_loss`` (default reduction). Comparing in OKLab
    handles hue circularity (359° ≈ 1°): hue becomes a direction in the a-b
    plane rather than a raw angle that wraps.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        pred_hsl: torch.Tensor,  # (B, 3) predicted HSL in [0,1]
        target_hsl: torch.Tensor,  # (B, 3) target HSL in [0,1]
    ) -> torch.Tensor:
        """Return the OKLab-space MSE between ``pred_hsl`` and ``target_hsl``."""
        pred_lab, target_lab = (
            hsl_to_oklab_batch(colours) for colours in (pred_hsl, target_hsl)
        )
        return torch.nn.functional.mse_loss(pred_lab, target_lab)