CarolineM5 committed
Commit b3652cf · verified · Parent(s): 7402c4e

Upload 2 files

Files changed (2)
  1. app.py +18 -16
  2. inference.py +5 -15
app.py CHANGED
@@ -68,21 +68,22 @@ pipe = StableDiffusionInstructPix2PixPipeline(
 
 pipe = pipe.to(torch.float32).to(device)
 
-# --- 3) GRADIO INTERFACE FUNCTION ---
-def gradio_generate(fibers_map: Image.Image, rings_map: Image.Image, num_steps):  # -> Image.Image
-    """
-    This function is called by Gradio on each upload.
-    It must return a PIL.Image (or a path to the saved image).
-    """
-    # Check that both images are in RGB mode (or adapt if needed)
+def gradio_generate(fibers_map: Image.Image,
+                    rings_map: Image.Image,
+                    num_steps: int) -> Image.Image:
+    # 1) normalize the mode
     fibers_map = fibers_map.convert("RGB")
     rings_map = rings_map.convert("RGB")
-
-    result_img = inference(pipe, rings_map, fibers_map, num_steps)
-
+
+
+    # 3) call inference with the seed
+    result_img = inference(pipe,
+                           rings_map,
+                           fibers_map,
+                           num_steps)
     return result_img
 
-# --- 4) GRADIO INTERFACE DEFINITION ---
+
 iface = gr.Interface(
     fn=gradio_generate,
     inputs=[
@@ -90,13 +91,17 @@ iface = gr.Interface(
         gr.Image(type="pil", label="Growth ring map"),
         gr.Number(value=10, label="Number of inference steps")
     ],
-    outputs=gr.Image(type="pil", label="Photorealistic wood generated"),
+    outputs=gr.Image(
+        type="pil",
+        label="Photorealistic wood generated",
+        format="png"  # force .png on download
+    ),
     title="Photorealistic wood generator",
     description="""
     Upload :
     1) a fibre orientation map,
     2) a growth ring map.
-
+
     Set the number of inference steps.
     Higher values can improve quality but increase processing time.
@@ -104,10 +109,7 @@ iface = gr.Interface(
     """
 )
 
-# --- 5) LAUNCH THE APP ---
 if __name__ == "__main__":
-    # You can set `server_name="0.0.0.0"` if you want it reachable on the network,
-    # and `server_port=7860` (or another port) to customize it.
     iface.launch(server_name="0.0.0.0", server_port=7860, share=False)
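Note on the output change: the added `format="png"` assumes a Gradio version whose `gr.Image` accepts a `format` argument (Gradio 4.x); there it makes the rendered/downloaded result a lossless PNG instead of Gradio's default format (WebP in recent releases). A minimal standalone sketch of the updated output wiring, with a hypothetical `echo` function standing in for `gradio_generate`:

```python
import gradio as gr
from PIL import Image

def echo(img: Image.Image) -> Image.Image:
    # Stand-in for gradio_generate: normalize the mode and return the image.
    return img.convert("RGB")

iface = gr.Interface(
    fn=echo,
    inputs=gr.Image(type="pil", label="Input map"),
    outputs=gr.Image(
        type="pil",
        label="Photorealistic wood generated",
        format="png",  # save/download as lossless PNG instead of the WebP default
    ),
)

if __name__ == "__main__":
    iface.launch()
```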
inference.py CHANGED
@@ -12,6 +12,7 @@ import numpy as np
 import torch.nn as nn
 from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
 from PIL import Image
+import random
 
 class UNetNoCondWrapper(nn.Module):
     def __init__(self, base_unet: UNet2DModel):
@@ -41,8 +42,11 @@ class UNetNoCondWrapper(nn.Module):
         return self.unet.save_pretrained(save_directory, **kwargs)
 
 def inference(pipe, img1, img2, num_steps):
+
+    seed = random.randrange(0, 2**32)
+    torch.manual_seed(seed)
 
-    generator = torch.Generator("cpu").manual_seed(0)
+    generator = torch.Generator("cpu").manual_seed(seed)
 
     img1 = img1.resize((512, 512))
     img2 = img2.resize((512, 512))
@@ -65,20 +69,6 @@ def inference(pipe, img1, img2, num_steps):
 
     all_images = []
 
-    # def cb_fn(step, timestep, latents):
-    #     # 1) Decode
-    #     with torch.no_grad():
-    #         decoded_output = pipe.vae.decode(latents / pipe.vae.config.scaling_factor)
-    #         decoded_tensor = decoded_output.sample  # (B, C, H, W)
-
-    #     # 2) Convert to NumPy (channels last) and to uint8 [0–255]
-    #     t = decoded_tensor.cpu().clamp(0, 1)[0]  # (C, H, W)
-    #     arr = (t.permute(1, 2, 0).numpy() * 255).astype(np.uint8)  # (H, W, C)
-
-    #     # 3) Create the PIL.Image
-    #     img = Image.fromarray(arr)
-    #     all_images.append(img)
-
     num_inference_steps = num_steps
     image_guidance_scale = 1.9
     guidance_scale = 10
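Note on the seeding change: the diff replaces the fixed `manual_seed(0)` with a fresh random seed on every call, so repeated uploads sample different results while any single result stays reproducible if the seed is recorded. A minimal sketch of the pattern (the `make_generator` helper is illustrative, not part of the commit):

```python
import random
import torch

def make_generator():
    # New scheme from the diff: a fresh seed per call instead of manual_seed(0).
    seed = random.randrange(0, 2**32)
    torch.manual_seed(seed)                               # also seed torch's global RNG
    generator = torch.Generator("cpu").manual_seed(seed)  # RNG handed to the pipeline
    return generator, seed

gen, seed = make_generator()
latents = torch.randn((1, 4, 64, 64), generator=gen)

# Re-creating a generator from the recorded seed reproduces the same latents.
gen_replay = torch.Generator("cpu").manual_seed(seed)
assert torch.equal(latents, torch.randn((1, 4, 64, 64), generator=gen_replay))
```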