| """
|
| MatFuse — PBR Material Generation Demo
|
|
|
| Gradio app for generating physically-based rendering (PBR) material maps
|
| using the MatFuse diffusion model. Supports text, image, sketch, and
|
| color-palette conditioning.
|
|
|
| Designed for Hugging Face Spaces with ZeroGPU support.
|
| """
|
|
|
| import os
|
| import random
|
| from typing import Optional
|
|
|
| import gradio as gr
|
| import numpy as np
|
| import spaces
|
| import torch
|
| from diffusers import DiffusionPipeline
|
| from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
# Hub repository the pipeline is loaded from; the MATFUSE_REPO environment
# variable overrides the default.
REPO_ID = os.getenv("MATFUSE_REPO", "gvecchio/MatFuse")

# Loaded once at import time, with float16 weights. trust_remote_code=True is
# passed so the pipeline code shipped with the repository can be used.
pipe = DiffusionPipeline.from_pretrained(
    REPO_ID,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_palette(image: Image.Image, n_colors: int = 5) -> list[list[int]]:
    """Extract dominant colors from an image using simple K-Means.

    The image is converted to RGB and resized to 64x64, so clustering always
    runs over 4096 pixels. Initial centroids are drawn with a fixed-seed
    generator, which makes the result deterministic for a given image.

    Args:
        image: Source image; it is converted to RGB before processing.
        n_colors: Number of clusters, i.e. palette entries, to compute.

    Returns:
        ``n_colors`` colors, each an ``[R, G, B]`` list of ints in 0-255.

    Raises:
        ValueError: If ``n_colors`` is smaller than 1 or larger than the
            number of pixels being clustered.
    """
    img = image.convert("RGB").resize((64, 64))
    pixels = np.array(img).reshape(-1, 3).astype(np.float32)

    # Fail with a clear message instead of an obscure NumPy error raised from
    # inside the sampling / argmin calls below.
    if not 1 <= n_colors <= len(pixels):
        raise ValueError(
            f"n_colors must be between 1 and {len(pixels)}, got {n_colors}"
        )

    rng = np.random.default_rng(0)
    centroids = pixels[rng.choice(len(pixels), n_colors, replace=False)]

    prev_labels = None
    for _ in range(20):
        dists = np.linalg.norm(pixels[:, None] - centroids[None], axis=2)
        labels = dists.argmin(axis=1)
        # An unchanged assignment means the centroids are already the means
        # of their clusters; further iterations would recompute the same
        # values, so stop early.
        if prev_labels is not None and np.array_equal(labels, prev_labels):
            break
        for k in range(n_colors):
            mask = labels == k
            # Empty clusters keep their previous centroid.
            if mask.any():
                centroids[k] = pixels[mask].mean(axis=0)
        prev_labels = labels
    return centroids.clip(0, 255).astype(np.uint8).tolist()
|
|
|
|
|
def palette_to_image(colors: list[list[int]], height: int = 50) -> Image.Image:
    """Render a palette swatch strip.

    Args:
        colors: Palette entries, each an ``[R, G, B]`` sequence.
        height: Height of the strip in pixels.

    Returns:
        An RGB image ``60 * len(colors)`` pixels wide that holds one solid
        60-pixel-wide swatch per color, ordered left to right.
    """
    w_each = 60
    img = Image.new("RGB", (w_each * len(colors), height))
    for i, c in enumerate(colors):
        left = i * w_each
        # Passing a color instead of an image makes paste() fill the box,
        # which replaces setting every pixel of the swatch individually.
        img.paste(tuple(c), (left, 0, left + w_each, height))
    return img
|
|
|
|
|
|
|
|
|
|
|
|
|
@spaces.GPU
@torch.inference_mode()
def generate(
    prompt: Optional[str],
    image: Optional[Image.Image],
    palette_image: Optional[Image.Image],
    sketch: Optional[Image.Image],
    guidance_scale: float,
    num_steps: int,
    seed: int,
    randomize_seed: bool,
):
    """Run the MatFuse pipeline and return the four PBR maps + palette preview.

    Only the conditions that were actually supplied are forwarded to the
    pipeline; a blank prompt or a missing image is simply left out.

    Args:
        prompt: Text condition; ignored when empty or whitespace-only.
        image: Reference image condition, or None.
        palette_image: Image whose five dominant colors form the palette
            condition, or None.
        sketch: Sketch condition, or None.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_steps: Number of inference steps.
        seed: Seed for the random generator; replaced by a random value when
            ``randomize_seed`` is set or when no seed was provided.
        randomize_seed: Whether to draw a fresh random seed.

    Returns:
        A tuple ``(diffuse, normal, roughness, specular, palette_preview,
        seed)``. ``palette_preview`` is None when no palette image was given,
        and ``seed`` is the integer seed that was actually used.
    """
    # The values come from UI widgets and may arrive as floats, or as None
    # for an empty seed field (assumption about the widgets - confirm against
    # the Gradio docs). manual_seed() needs an int, so normalize up front.
    if randomize_seed or seed is None:
        seed = random.randint(0, 2**31 - 1)
    seed = int(seed)

    pipe.to("cuda")

    kwargs: dict = dict(
        num_inference_steps=int(num_steps),
        guidance_scale=guidance_scale,
        generator=torch.Generator("cuda").manual_seed(seed),
    )

    if prompt and prompt.strip():
        kwargs["text"] = prompt.strip()

    if image is not None:
        kwargs["image"] = image

    if sketch is not None:
        kwargs["sketch"] = sketch

    palette_preview = None
    if palette_image is not None:
        colors = extract_palette(palette_image, n_colors=5)
        # The pipeline receives the palette as float32 values scaled to 0-1;
        # the preview strip is rendered from the original 0-255 colors.
        kwargs["palette"] = np.array(colors, dtype=np.float32) / 255.0
        palette_preview = palette_to_image(colors)

    result = pipe(**kwargs)

    # Each entry is indexed per map name; take the first item of each.
    diffuse_img = result["diffuse"][0]
    normal_img = result["normal"][0]
    roughness_img = result["roughness"][0]
    specular_img = result["specular"][0]

    return diffuse_img, normal_img, roughness_img, specular_img, palette_preview, seed
|
|
|
|
|
|
|
|
|
|
|
|
|
# Text prompts offered as ready-made examples.
EXAMPLE_PROMPTS: list[str] = [
    "Red brick wall with white mortar",
    "Polished oak wood floor",
    "Rough concrete with cracks",
    "Mossy cobblestone path",
    "Shiny marble tiles",
    "Rusted metal panel",
]
|
|
|
|
|
|
|
|
|
|
|
|
|
# Custom stylesheet: centers the title and subtitle, rounds the corners of
# images inside ".output-map" elements, and hides the page footer.
css = "\n".join(
    [
        "",
        "#matfuse-title { text-align: center; margin-bottom: 0.5em; }",
        "#matfuse-subtitle { text-align: center; color: #666; margin-top: 0; }",
        ".output-map img { border-radius: 8px; }",
        "footer { display: none !important; }",
        "",
    ]
)
|
|
|
# Two-column page: conditioning inputs and settings on the left, generated
# maps on the right.
with gr.Blocks(title="MatFuse — PBR Material Generator") as demo:

    # Header
    gr.Markdown("# MatFuse", elem_id="matfuse-title")
    gr.Markdown(
        "Generate seamless PBR material maps (diffuse, normal, roughness, specular) "
        "from text, images, sketches, and color palettes. "
        "[Paper](https://arxiv.org/abs/2308.11408) | "
        "[Code](https://github.com/gvecchio/matfuse-sd)",
        elem_id="matfuse-subtitle",
    )

    with gr.Row():

        # Inputs
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Text prompt",
                placeholder="e.g. 'Old wooden floor with scratches'",
                lines=2,
            )

            with gr.Accordion("Image conditioning", open=False):
                image_input = gr.Image(
                    label="Reference image",
                    type="pil",
                    sources=["upload", "clipboard"],
                )

            with gr.Accordion("Palette conditioning", open=False):
                palette_image = gr.Image(
                    label="Upload image to extract palette",
                    type="pil",
                    sources=["upload", "clipboard"],
                )

            with gr.Accordion("Sketch conditioning", open=False):
                sketch_input = gr.Image(
                    label="Binary sketch / edge map",
                    type="pil",
                    image_mode="L",
                    sources=["upload", "clipboard"],
                )

            with gr.Accordion("Generation settings", open=False):
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=1.0,
                    maximum=15.0,
                    value=4.0,
                    step=0.5,
                )
                num_steps = gr.Slider(
                    label="Inference steps",
                    minimum=10,
                    maximum=100,
                    value=50,
                    step=5,
                )
                with gr.Row():
                    seed = gr.Number(label="Seed", value=42, precision=0)
                    randomize_seed = gr.Checkbox(label="Randomize", value=True)

            generate_btn = gr.Button("Generate", variant="primary", size="lg")

            gr.Examples(
                examples=[[p] for p in EXAMPLE_PROMPTS],
                inputs=[prompt],
                label="Example prompts",
            )

        # Outputs
        with gr.Column(scale=1):
            with gr.Row():
                diffuse_out = gr.Image(label="Diffuse", elem_classes="output-map", interactive=False)
                normal_out = gr.Image(label="Normal", elem_classes="output-map", interactive=False)
            with gr.Row():
                roughness_out = gr.Image(label="Roughness", elem_classes="output-map", interactive=False)
                specular_out = gr.Image(label="Specular", elem_classes="output-map", interactive=False)
            palette_out = gr.Image(label="Extracted palette", visible=True, height=60, interactive=False)
            seed_out = gr.Number(label="Seed used", interactive=False)

    # Wiring: the input order must match the parameter order of generate(),
    # and the output order must match the tuple it returns.
    generate_btn.click(
        fn=generate,
        inputs=[
            prompt,
            image_input,
            palette_image,
            sketch_input,
            guidance_scale,
            num_steps,
            seed,
            randomize_seed,
        ],
        outputs=[diffuse_out, normal_out, roughness_out, specular_out, palette_out, seed_out],
    )
|
|
|
# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(
        theme=gr.themes.Soft(),
        css=css,
    )
|
|
|
|
|
| |