danube2024 committed on
Commit
ddef426
·
verified ·
1 Parent(s): cb71cae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -41
app.py CHANGED
@@ -1,41 +1,31 @@
1
  import gradio as gr
2
  import torch
3
  import numpy as np
4
-
5
  from diffusers import StableDiffusionXLPipeline
6
  from transformers import DPTFeatureExtractor, DPTForDepthEstimation
7
  from PIL import Image, ImageEnhance, ImageOps
8
 
9
- ############################################
10
- # 1. Setup and Model Loading
11
- ############################################
12
- device = "cpu" # or "cuda" if GPU is available
13
- torch_dtype = torch.float32 # if using CPU or float16 for GPU
14
 
15
  print("Loading SDXL Base model...")
16
  pipe = StableDiffusionXLPipeline.from_pretrained(
17
  "stabilityai/stable-diffusion-xl-base-1.0",
18
  torch_dtype=torch_dtype
19
- )
20
- pipe.to(device)
21
 
22
- print("Loading bas-relief LoRA weights...")
23
- # IMPORTANT: Pass the first argument as a string to the repo or path,
24
- # and `weight_name` as a kwarg. That matches the actual function signature.
25
  pipe.load_lora_weights(
26
- "KappaNeuro/bas-relief", # repo / path
27
- weight_name="BAS-RELIEF.safetensors"
 
28
  )
29
 
30
- print("Loading DPT Depth model...")
31
  feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
32
  depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)
33
 
34
-
35
  def enhance_depth_map(depth_arr: np.ndarray) -> Image.Image:
36
- """
37
- Normalize depth to [0, 255], auto-contrast, and sharpen.
38
- """
39
  d_min, d_max = depth_arr.min(), depth_arr.max()
40
  depth_stretched = (depth_arr - d_min) / (d_max - d_min + 1e-8)
41
  depth_stretched = (depth_stretched * 255).astype(np.uint8)
@@ -48,52 +38,47 @@ def enhance_depth_map(depth_arr: np.ndarray) -> Image.Image:
48
 
49
  return depth_pil
50
 
51
-
52
- def generate_bas_relief_and_depth(prompt: str):
53
- # We prepend "BAS-RELIEF" to ensure the LoRA style is triggered.
54
  full_prompt = f"BAS-RELIEF {prompt}"
55
-
56
- print("Generating bas-relief image...")
57
  result = pipe(
58
  prompt=full_prompt,
59
- num_inference_steps=15, # Lower for speed on CPU
60
  guidance_scale=7.5,
61
- height=512,
62
  width=512
63
  )
64
- generated_image = result.images[0]
65
 
66
- print("Running depth estimation...")
67
- inputs = feature_extractor(generated_image, return_tensors="pt").to(device)
68
  with torch.no_grad():
69
  outputs = depth_model(**inputs)
70
  predicted_depth = outputs.predicted_depth
71
 
72
- # Resize depth map to match original image
73
  prediction = torch.nn.functional.interpolate(
74
  predicted_depth.unsqueeze(1),
75
- size=generated_image.size[::-1],
76
  mode="bicubic",
77
- align_corners=False,
78
- ).squeeze(0)
79
-
80
- depth_arr = prediction.cpu().numpy()
81
- depth_pil = enhance_depth_map(depth_arr)
82
 
83
- return generated_image, depth_pil
84
 
 
85
 
86
- title = "Bas-Relief (SDXL + LoRA) + Depth Map"
87
  description = (
88
- "Load SDXL base on CPU, apply 'BAS-RELIEF.safetensors' LoRA from KappaNeuro/bas-relief. "
89
- "Then run DPT for depth estimation."
90
  )
91
 
92
  iface = gr.Interface(
93
  fn=generate_bas_relief_and_depth,
94
  inputs=gr.Textbox(
95
- label="Describe your scene/style",
96
- placeholder="e.g., 'sculpture of a woman in shibari, marble, intricate details'"
97
  ),
98
  outputs=[gr.Image(label="Bas-Relief Image"), gr.Image(label="Depth Map")],
99
  title=title,
 
1
  import gradio as gr
2
  import torch
3
  import numpy as np
 
4
  from diffusers import StableDiffusionXLPipeline
5
  from transformers import DPTFeatureExtractor, DPTForDepthEstimation
6
  from PIL import Image, ImageEnhance, ImageOps
7
 
8
# Inference device and the tensor dtype used when loading the SDXL weights.
device = "cpu"  # or "cuda" if you have a GPU
torch_dtype = torch.float32

# Load the SDXL base pipeline once at import time and move it to `device`.
print("Loading SDXL Base model...")
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch_dtype,
).to(device)

# Attach the bas-relief LoRA to the pipeline.
# NOTE(review): the previous call also passed `peft_backend="peft"`. That is
# not a documented argument of `load_lora_weights`, so it has been dropped;
# confirm against the installed diffusers version that the PEFT integration
# is picked up from the environment rather than from a call-site flag.
print("Loading bas-relief LoRA weights with PEFT...")
pipe.load_lora_weights(
    "KappaNeuro/bas-relief",  # The HF repo with BAS-RELIEF.safetensors
    weight_name="BAS-RELIEF.safetensors",
)

# Load the DPT preprocessor and depth model used on the generated image.
print("Loading DPT Depth Model...")
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)
27
 
 
28
  def enhance_depth_map(depth_arr: np.ndarray) -> Image.Image:
 
 
 
29
  d_min, d_max = depth_arr.min(), depth_arr.max()
30
  depth_stretched = (depth_arr - d_min) / (d_max - d_min + 1e-8)
31
  depth_stretched = (depth_stretched * 255).astype(np.uint8)
 
38
 
39
  return depth_pil
40
 
41
def generate_bas_relief_and_depth(prompt: str):
    """Generate a bas-relief styled image and a depth map derived from it.

    The prompt is prefixed with the "BAS-RELIEF" token, rendered with the
    module-level SDXL pipeline ``pipe``, and the resulting image is passed
    through ``feature_extractor`` / ``depth_model``. The predicted depth is
    resized to the generated image's size and post-processed by
    ``enhance_depth_map``.

    Args:
        prompt: Free-text description of the scene. ``None`` and
            surrounding whitespace are tolerated and treated as empty text.

    Returns:
        A ``(image, depth_map_pil)`` tuple: the first image produced by the
        pipeline and the value returned by ``enhance_depth_map``.
    """
    # A missing prompt must not be rendered as the literal text "None", and
    # stray whitespace should not leave a dangling separator after the token.
    subject = (prompt or "").strip()

    # Use the token "BAS-RELIEF" so the LoRA triggers
    full_prompt = f"BAS-RELIEF {subject}".strip()

    print("Generating image with LoRA style...")
    result = pipe(
        prompt=full_prompt,
        num_inference_steps=15,  # reduce if too slow
        guidance_scale=7.5,
        height=512,  # reduce if you still get timeouts
        width=512,
    )
    image = result.images[0]

    print("Running DPT Depth Estimation...")
    inputs = feature_extractor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = depth_model(**inputs)
        predicted_depth = outputs.predicted_depth

    # PIL reports size as (width, height); interpolate expects (height, width).
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
        align_corners=False,
    )

    # Drop exactly the batch axis and the channel axis added by unsqueeze(1).
    # A bare squeeze() would also remove any other axis of length one.
    depth_2d = prediction[0, 0]

    depth_map_pil = enhance_depth_map(depth_2d.cpu().numpy())

    return image, depth_map_pil
70
 
71
+ title = "Bas-Relief (SDXL + LoRA) + Depth Map (with PEFT)"
72
  description = (
73
+ "Loads stable-diffusion-xl-base-1.0 on CPU, merges LoRA from 'KappaNeuro/bas-relief'. "
74
+ "Use 'BAS-RELIEF' token in your prompt to trigger the style, then compute a depth map."
75
  )
76
 
77
  iface = gr.Interface(
78
  fn=generate_bas_relief_and_depth,
79
  inputs=gr.Textbox(
80
+ label="Description",
81
+ placeholder="woman in shibari, marble relief, intricately carved"
82
  ),
83
  outputs=[gr.Image(label="Bas-Relief Image"), gr.Image(label="Depth Map")],
84
  title=title,