gvecchio commited on
Commit
9a13713
1 Parent(s): 0c16875
Files changed (4) hide show
  1. README.md +3 -3
  2. app.py +57 -0
  3. generation.py +51 -0
  4. requirements.txt +6 -0
README.md CHANGED
@@ -1,13 +1,13 @@
1
  ---
2
  title: StableMaterials
3
- emoji: 🦀
4
  colorFrom: blue
5
  colorTo: blue
6
  sdk: gradio
7
- sdk_version: 4.32.2
8
  app_file: app.py
9
  pinned: false
10
  license: openrail
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: StableMaterials
3
+ emoji: 🧱
4
  colorFrom: blue
5
  colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 4.36.1
8
  app_file: app.py
9
  pinned: false
10
  license: openrail
11
  ---
12
 
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ import torch
4
+ import diffusers
5
+
6
+ from generation import generate_material
7
+
8
+
9
+ @spaces.GPU
10
+ def generate(prompts, seed, resolution, refinement):
11
+ image = generate_material(prompts, seed=seed, resolution=int(resolution), refinement=refinement)
12
+ return image.basecolor, image.normal, image.height, image.metallic, image.roughness
13
+
14
+
15
+ def interface_function(prompt_type, text_prompt, image_prompt, seed, resolution, refinement):
16
+ if prompt_type == "Text":
17
+ return generate(text_prompt, seed, resolution, refinement)
18
+ elif prompt_type == "Image":
19
+ return generate(image_prompt, seed, resolution, refinement)
20
+
21
+ def update_visibility(prompt_type):
22
+ if prompt_type == "Text":
23
+ return gr.update(visible=True), gr.update(visible=False)
24
+ elif prompt_type == "Image":
25
+ return gr.update(visible=False), gr.update(visible=True)
26
+
27
+ with gr.Blocks() as demo:
28
+ with gr.Row():
29
+ with gr.Column():
30
+ prompt_type = gr.Radio(choices=["Text", "Image"], label="Prompt Type", value="Text")
31
+ text_prompt = gr.Textbox(label="Text Prompt", visible=True, lines=3, placeholder="A brick wall")
32
+ image_prompt = gr.Image(type="pil", label="Image Prompt", visible=False)
33
+
34
+ with gr.Column():
35
+ seed = gr.Number(value=-1, label="Seed (-1 for random)")
36
+ resolution = gr.Dropdown(["512", "1024", "2048"], value="512", label="Resolution", interactive=False)
37
+ refinement = gr.Checkbox(label="Refinement", interactive=False)
38
+ generate_button = gr.Button("Generate")
39
+
40
+ prompt_type.change(fn=update_visibility, inputs=prompt_type, outputs=[text_prompt, image_prompt])
41
+
42
+ with gr.Row():
43
+ output_basecolor = gr.Image(label="Base Color", format="png", image_mode="RGB")
44
+ output_normal = gr.Image(label="Normal Map", format="png", image_mode="RGB")
45
+ output_height = gr.Image(label="Height Map", format="png", image_mode="L")
46
+ output_metallic = gr.Image(label="Metallic Map", format="png", image_mode="L")
47
+ output_roughness = gr.Image(label="Roughness Map", format="png", image_mode="L")
48
+
49
+ generate_button.click(
50
+ fn=interface_function,
51
+ inputs=[prompt_type, text_prompt, image_prompt, seed, resolution, refinement],
52
+ outputs=[output_basecolor, output_normal, output_height, output_metallic, output_roughness]
53
+ )
54
+
55
+ if __name__ == "__main__":
56
+ demo.launch()
57
+
generation.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import DiffusionPipeline, LCMScheduler, UNet2DConditionModel
3
+
4
+ import logging
5
+
6
+ # Configure logging
7
+ logging.basicConfig(level=logging.INFO)
8
+ logger = logging.getLogger(__name__)
9
+
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+
12
+ unet = UNet2DConditionModel.from_pretrained(
13
+ "gvecchio/StableMaterials",
14
+ subfolder="unet_lcm",
15
+ torch_dtype=torch.float16,
16
+ )
17
+
18
+ pipe = DiffusionPipeline.from_pretrained(
19
+ "gvecchio/StableMaterials",
20
+ trust_remote_code=True,
21
+ unet=unet,
22
+ torch_dtype=torch.float16
23
+ ).to(device)
24
+
25
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
26
+
27
+
28
+ def generate_material(prompt, seed=-1, resolution=512, refinement=False):
29
+ try:
30
+ seed = seed if seed != -1 else torch.randint(0, 10000, (1,)).item()
31
+
32
+ logger.info(f"Generating images for prompt: {prompt} with seed: {seed}")
33
+ generator = torch.Generator(device=pipe.device).manual_seed(seed)
34
+
35
+ image = pipe(
36
+ prompt=[prompt],
37
+ tileable=True,
38
+ num_images_per_prompt=1,
39
+ num_inference_steps=4,
40
+ generator=generator,
41
+ ).images[0]
42
+
43
+ image = image.resize((resolution, resolution))
44
+
45
+ if refinement:
46
+ pass
47
+
48
+ return image
49
+ except Exception as e:
50
+ logger.error(f"Exception occurred while generating images: {e}")
51
+ raise
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ accelerate
2
+ diffusers
3
+ invisible_watermark
4
+ torch
5
+ torchvision
6
+ transformers