# app.py — Age-first + FAST group cartoons (SD-Turbo), single page (HF Spaces safe)
import os
os.environ["TRANSFORMERS_NO_TF"] = "1"
os.environ["TRANSFORMERS_NO_FLAX"] = "1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import math
from typing import Optional
import gradio as gr
from PIL import Image, ImageDraw
import numpy as np
import torch
# ------------------ Age estimator ------------------
from transformers import AutoImageProcessor, AutoModelForImageClassification
HF_MODEL_ID = "nateraw/vit-age-classifier"
AGE_RANGE_TO_MID = {
"0-2": 1, "3-9": 6, "10-19": 15, "20-29": 25, "30-39": 35,
"40-49": 45, "50-59": 55, "60-69": 65, "70+": 75
}
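# The expected age below is a probability-weighted mean of these midpoints,
# e.g. P("20-29")=0.7 and P("30-39")=0.3 gives 0.7*25 + 0.3*35 = 28.0 years.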
class PretrainedAgeEstimator:
def __init__(self, model_id: str = HF_MODEL_ID, device: Optional[str] = None):
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.processor = AutoImageProcessor.from_pretrained(model_id, use_fast=True)
self.model = AutoModelForImageClassification.from_pretrained(model_id)
self.model.to(self.device).eval()
self.id2label = self.model.config.id2label
@torch.inference_mode()
def predict(self, img: Image.Image, topk: int = 5):
if img.mode != "RGB":
img = img.convert("RGB")
inputs = self.processor(images=img, return_tensors="pt").to(self.device)
logits = self.model(**inputs).logits
probs = logits.softmax(dim=-1).squeeze(0)
k = min(topk, probs.numel())
values, indices = torch.topk(probs, k=k)
top = [(self.id2label[i.item()], float(v.item())) for i, v in zip(indices, values)]
expected = sum(AGE_RANGE_TO_MID.get(self.id2label[i], 35) * float(p)
for i, p in enumerate(probs))
return expected, top
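# Usage sketch ("face.jpg" is a hypothetical local file):
#   est = PretrainedAgeEstimator()
#   age, top = est.predict(Image.open("face.jpg"))
#   # age -> e.g. 27.3, top -> [("20-29", 0.71), ("30-39", 0.18), ...]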
# ------------------ Face detection (single & group) ------------------
from facenet_pytorch import MTCNN
class FaceCropper:
"""
Detect faces.
- detect_one_wide: returns (crop_with_margin, annotated)
- detect_all_wide: returns (list[crops], annotated, list[boxes])
Boxes are (x1,y1,x2,y2) floats.
"""
def __init__(self, device: Optional[str] = None, margin_scale: float = 1.8):
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.mtcnn = MTCNN(keep_all=True, device=self.device)
self.margin_scale = margin_scale
def _ensure_pil(self, img):
if isinstance(img, Image.Image):
return img.convert("RGB")
return Image.fromarray(img).convert("RGB")
def _expand_box(self, box, W, H, aspect=0.8): # ~4:5 portrait (w/h=0.8)
x1, y1, x2, y2 = box
cx, cy = (x1 + x2)/2, (y1 + y2)/2
w, h = (x2 - x1), (y2 - y1)
side = max(w, h) * self.margin_scale
tw = side
th = side / aspect # taller than wide
nx1 = int(max(0, cx - tw/2)); nx2 = int(min(W, cx + tw/2))
ny1 = int(max(0, cy - th/2)); ny2 = int(min(H, cy + th/2))
return nx1, ny1, nx2, ny2
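        # Example: box (100, 100, 200, 220) in a 1000x1000 image with
        # margin_scale=1.8 gives side = 120*1.8 = 216, so the crop is
        # ~216 px wide and 216/0.8 = 270 px tall, centred on (150, 160)
        # and clamped to the image bounds.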
def detect_one_wide(self, img):
pil = self._ensure_pil(img)
W, H = pil.size
boxes, probs = self.mtcnn.detect(pil)
annotated = pil.copy()
draw = ImageDraw.Draw(annotated)
if boxes is None or len(boxes) == 0:
return None, annotated
# draw all boxes
for b, p in zip(boxes, probs):
bx1, by1, bx2, by2 = map(float, b)
draw.rectangle([bx1, by1, bx2, by2], outline=(255, 0, 0), width=3)
draw.text((bx1, max(0, by1-12)), f"{p:.2f}", fill=(255, 0, 0))
# choose largest
idx = int(np.argmax([(b[2]-b[0])*(b[3]-b[1]) for b in boxes]))
nx1, ny1, nx2, ny2 = self._expand_box(boxes[idx], W, H)
crop = pil.crop((nx1, ny1, nx2, ny2))
return crop, annotated
def detect_all_wide(self, img):
pil = self._ensure_pil(img)
W, H = pil.size
boxes, probs = self.mtcnn.detect(pil)
annotated = pil.copy()
draw = ImageDraw.Draw(annotated)
crops = []
ordered = []
if boxes is None or len(boxes) == 0:
return crops, annotated, []
# sort roughly left->right for table order
for b, p in sorted(zip(boxes, probs), key=lambda x: (x[0][0]+x[0][2])/2):
bx1, by1, bx2, by2 = map(float, b)
draw.rectangle([bx1, by1, bx2, by2], outline=(0, 200, 255), width=3)
draw.text((bx1, max(0, by1-12)), f"{p:.2f}", fill=(0, 200, 255))
nx1, ny1, nx2, ny2 = self._expand_box(b, W, H)
crops.append(pil.crop((nx1, ny1, nx2, ny2)))
ordered.append((bx1, by1, bx2, by2))
return crops, annotated, ordered
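# Usage sketch ("group.jpg" is a hypothetical local file):
#   crops, annotated, boxes = FaceCropper().detect_all_wide(Image.open("group.jpg"))
#   # crops: one widened portrait per face, ordered left-to-right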
# ------------------ FAST Cartoonizer (SD-Turbo) ------------------
from diffusers import AutoPipelineForImage2Image
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
TURBO_ID = "stabilityai/sd-turbo"
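# SD-Turbo is distilled for few-step sampling: the intended settings are
# 1-4 inference steps with guidance_scale=0.0. With classifier-free guidance
# disabled, the negative prompt has essentially no effect; it is kept for clarity.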
def load_turbo_pipe(device):
    dtype = torch.float16 if device == "cuda" else torch.float32
    pipe = AutoPipelineForImage2Image.from_pretrained(
        TURBO_ID,
        dtype=dtype,  # ✅ no deprecation warning
    )
    # safety checker ON for public Spaces; load it in the same dtype so it
    # matches the fp16 pipeline on GPU
    pipe.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker", torch_dtype=dtype
    )
    pipe.feature_extractor = AutoFeatureExtractor.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    )
    pipe.to(device)  # move all components together, incl. the safety checker
    try:
        pipe.enable_attention_slicing()  # lower peak memory; optional
    except Exception:
        pass
    return pipe
# init models once
age_est = PretrainedAgeEstimator()
cropper = FaceCropper(device=age_est.device, margin_scale=1.9)
sd_pipe = load_turbo_pipe(age_est.device)
# prompts
DEFAULT_POSITIVE = (
"beautiful princess portrait, elegant gown, tiara, soft magical lighting, "
"sparkles, dreamy castle background, painterly, clean lineart, vibrant but natural colors, "
"storybook illustration, high quality"
)
DEFAULT_NEGATIVE = (
"deformed, disfigured, ugly, extra limbs, extra fingers, bad anatomy, low quality, "
"blurry, watermark, text, logo"
)
def _ensure_pil(img):
return img if isinstance(img, Image.Image) else Image.fromarray(img)
def _resize_512(im: Image.Image):
w, h = im.size
scale = 512 / max(w, h)
if scale < 1.0:
im = im.resize((int(w*scale), int(h*scale)), Image.LANCZOS)
return im
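# Note: _resize_512 only downscales (scale < 1.0); e.g. a 1600x1200 crop becomes
# 512x384, while a 400x300 crop passes through unchanged (diffusers rounds
# dimensions internally).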
# ------------- AGE (single/group) -------------
@torch.inference_mode()
def predict_age(img, group_mode=False, auto_crop=True):
if img is None:
return {}, "Please upload an image.", None
pil = _ensure_pil(img).convert("RGB")
if group_mode:
crops, annotated, boxes = cropper.detect_all_wide(pil)
if not crops:
# fallback to full image
age, top = age_est.predict(pil, topk=5)
probs = {lbl: float(p) for lbl, p in top}
md = f"**Estimated age (whole image):** {age:.1f} years"
return probs, md, pil
# per-face ages
rows = ["| # | Age (yrs) | Top-1 | p |", "|---:|---:|---|---:|"]
for i, face in enumerate(crops, 1):
age, top = age_est.predict(face, topk=3)
top1, p1 = top[0]
rows.append(f"| {i} | {age:.1f} | {top1} | {p1:.2f} |")
md = "\n".join(rows)
        # the Label widget expects a single dict, so use the first (left-most) face
age0, top0 = age_est.predict(crops[0], topk=5)
probs0 = {lbl: float(p) for lbl, p in top0}
return probs0, md, annotated
    # single-photo path
    face_wide, annotated = None, None
if auto_crop:
face_wide, annotated = cropper.detect_one_wide(pil)
target = face_wide if face_wide is not None else pil
age, top = age_est.predict(target, topk=5)
probs = {lbl: float(p) for lbl, p in top}
md = f"**Estimated age:** {age:.1f} years"
return probs, md, (annotated if annotated is not None else pil)
# ------------- CARTOON (single/group) -------------
@torch.inference_mode()
def cartoonize(img, prompt="", group_mode=False, auto_crop=True, strength=0.5, steps=2, seed=-1):
if img is None:
return None
pil = _ensure_pil(img).convert("RGB")
user = (prompt or "").strip()
pos = DEFAULT_POSITIVE if not user else f"{DEFAULT_POSITIVE}, {user}"
neg = DEFAULT_NEGATIVE
generator = None
if isinstance(seed, (int, float)) and int(seed) >= 0:
generator = torch.Generator(device=age_est.device).manual_seed(int(seed))
if group_mode:
# detect all faces, stylize each, assemble grid
crops, _, _ = cropper.detect_all_wide(pil)
if not crops:
crops = [pil] # fallback
proc = []
for c in crops:
c = _resize_512(c)
out = sd_pipe(
prompt=pos, negative_prompt=neg, image=c,
strength=float(strength), guidance_scale=0.0,
num_inference_steps=int(steps), generator=generator
)
proc.append(out.images[0])
# tile into a grid
n = len(proc)
cols = int(math.ceil(math.sqrt(n)))
rows = int(math.ceil(n / cols))
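        # e.g. n=5 faces -> cols=ceil(sqrt(5))=3, rows=ceil(5/3)=2,
        # leaving the last grid row partially empty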
cell_w = max(im.width for im in proc)
cell_h = max(im.height for im in proc)
grid = Image.new("RGB", (cols * cell_w, rows * cell_h), (240, 240, 240))
for i, im in enumerate(proc):
r, c = divmod(i, cols)
grid.paste(im, (c * cell_w, r * cell_h))
return grid
# single person
face_wide = None
if auto_crop:
face_wide, _ = cropper.detect_one_wide(pil)
base = face_wide if face_wide is not None else pil
base = _resize_512(base)
out = sd_pipe(
prompt=pos, negative_prompt=neg, image=base,
strength=float(strength), guidance_scale=0.0,
num_inference_steps=int(steps), generator=generator
)
return out.images[0]
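# Standalone usage sketch (outside the UI; "photo.jpg" is a hypothetical file):
#   result = cartoonize(Image.open("photo.jpg"), prompt="watercolor, pastel",
#                       strength=0.55, steps=2, seed=42)
#   result.save("cartoon.png")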
# ------------------ UI ------------------
with gr.Blocks(title="Group Age + Cartoons (Fast)") as demo:
gr.Markdown("# Predict ages and make fast cartoons — single or group photos")
with gr.Row():
with gr.Column(scale=1):
img_in = gr.Image(sources=["upload", "webcam"], type="pil", label="Upload / Webcam")
group_mode = gr.Checkbox(False, label="Group photo (detect everyone)")
auto = gr.Checkbox(True, label="Auto face crop (wide)")
prompt = gr.Textbox(label="(Optional) Extra cartoon style",
placeholder="e.g., studio ghibli watercolor, soft bokeh, pastel palette")
with gr.Row():
strength = gr.Slider(0.3, 0.8, value=0.5, step=0.05, label="Cartoon strength")
steps = gr.Slider(1, 4, value=2, step=1, label="Turbo steps (1–4)")
seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
btn_age = gr.Button("Predict Age(s) (fast)", variant="primary")
btn_cartoon = gr.Button("Make Cartoon(s) (fast)", variant="secondary")
with gr.Column(scale=1):
probs_out = gr.Label(num_top_classes=5, label="Age Prediction (probabilities, first face)")
age_md = gr.Markdown(label="Age Table / Summary")
preview = gr.Image(label="Detection Preview (boxes)")
cartoon_out = gr.Image(label="Cartoon Result (grid for groups)")
btn_age.click(fn=predict_age, inputs=[img_in, group_mode, auto], outputs=[probs_out, age_md, preview])
btn_cartoon.click(fn=cartoonize, inputs=[img_in, prompt, group_mode, auto, strength, steps, seed], outputs=cartoon_out)
# Expose for Hugging Face Spaces
app = demo
if __name__ == "__main__":
app.queue().launch()