Update app.py
Browse files
app.py
CHANGED
@@ -1,41 +1,31 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
import numpy as np
|
4 |
-
|
5 |
from diffusers import StableDiffusionXLPipeline
|
6 |
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
|
7 |
from PIL import Image, ImageEnhance, ImageOps
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
############################################
|
12 |
-
device = "cpu" # or "cuda" if GPU is available
|
13 |
-
torch_dtype = torch.float32 # if using CPU or float16 for GPU
|
14 |
|
15 |
print("Loading SDXL Base model...")
|
16 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
17 |
"stabilityai/stable-diffusion-xl-base-1.0",
|
18 |
torch_dtype=torch_dtype
|
19 |
-
)
|
20 |
-
pipe.to(device)
|
21 |
|
22 |
-
print("Loading bas-relief LoRA weights...")
|
23 |
-
# IMPORTANT: Pass the first argument as a string to the repo or path,
|
24 |
-
# and `weight_name` as a kwarg. That matches the actual function signature.
|
25 |
pipe.load_lora_weights(
|
26 |
-
"KappaNeuro/bas-relief",
|
27 |
-
weight_name="BAS-RELIEF.safetensors"
|
|
|
28 |
)
|
29 |
|
30 |
-
print("Loading DPT Depth
|
31 |
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
|
32 |
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)
|
33 |
|
34 |
-
|
35 |
def enhance_depth_map(depth_arr: np.ndarray) -> Image.Image:
|
36 |
-
"""
|
37 |
-
Normalize depth to [0, 255], auto-contrast, and sharpen.
|
38 |
-
"""
|
39 |
d_min, d_max = depth_arr.min(), depth_arr.max()
|
40 |
depth_stretched = (depth_arr - d_min) / (d_max - d_min + 1e-8)
|
41 |
depth_stretched = (depth_stretched * 255).astype(np.uint8)
|
@@ -48,52 +38,47 @@ def enhance_depth_map(depth_arr: np.ndarray) -> Image.Image:
|
|
48 |
|
49 |
return depth_pil
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
# We prepend "BAS-RELIEF" to ensure the LoRA style is triggered.
|
54 |
full_prompt = f"BAS-RELIEF {prompt}"
|
55 |
-
|
56 |
-
print("Generating bas-relief image...")
|
57 |
result = pipe(
|
58 |
prompt=full_prompt,
|
59 |
-
num_inference_steps=15,
|
60 |
guidance_scale=7.5,
|
61 |
-
height=512,
|
62 |
width=512
|
63 |
)
|
64 |
-
|
65 |
|
66 |
-
print("Running
|
67 |
-
inputs = feature_extractor(
|
68 |
with torch.no_grad():
|
69 |
outputs = depth_model(**inputs)
|
70 |
predicted_depth = outputs.predicted_depth
|
71 |
|
72 |
-
# Resize depth map to match original image
|
73 |
prediction = torch.nn.functional.interpolate(
|
74 |
predicted_depth.unsqueeze(1),
|
75 |
-
size=
|
76 |
mode="bicubic",
|
77 |
-
align_corners=False
|
78 |
-
).squeeze(
|
79 |
-
|
80 |
-
depth_arr = prediction.cpu().numpy()
|
81 |
-
depth_pil = enhance_depth_map(depth_arr)
|
82 |
|
83 |
-
|
84 |
|
|
|
85 |
|
86 |
-
title = "Bas-Relief (SDXL + LoRA) + Depth Map"
|
87 |
description = (
|
88 |
-
"
|
89 |
-
"
|
90 |
)
|
91 |
|
92 |
iface = gr.Interface(
|
93 |
fn=generate_bas_relief_and_depth,
|
94 |
inputs=gr.Textbox(
|
95 |
-
label="
|
96 |
-
placeholder="
|
97 |
),
|
98 |
outputs=[gr.Image(label="Bas-Relief Image"), gr.Image(label="Depth Map")],
|
99 |
title=title,
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
import numpy as np
|
|
|
4 |
from diffusers import StableDiffusionXLPipeline
|
5 |
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
|
6 |
from PIL import Image, ImageEnhance, ImageOps
|
7 |
|
8 |
+
device = "cpu" # or "cuda" if you have a GPU
|
9 |
+
torch_dtype = torch.float32
|
|
|
|
|
|
|
10 |
|
11 |
print("Loading SDXL Base model...")
|
12 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
13 |
"stabilityai/stable-diffusion-xl-base-1.0",
|
14 |
torch_dtype=torch_dtype
|
15 |
+
).to(device)
|
|
|
16 |
|
17 |
+
print("Loading bas-relief LoRA weights with PEFT...")
|
|
|
|
|
18 |
pipe.load_lora_weights(
|
19 |
+
"KappaNeuro/bas-relief", # The HF repo with BAS-RELIEF.safetensors
|
20 |
+
weight_name="BAS-RELIEF.safetensors",
|
21 |
+
peft_backend="peft" # This is crucial
|
22 |
)
|
23 |
|
24 |
+
print("Loading DPT Depth Model...")
|
25 |
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
|
26 |
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)
|
27 |
|
|
|
28 |
def enhance_depth_map(depth_arr: np.ndarray) -> Image.Image:
|
|
|
|
|
|
|
29 |
d_min, d_max = depth_arr.min(), depth_arr.max()
|
30 |
depth_stretched = (depth_arr - d_min) / (d_max - d_min + 1e-8)
|
31 |
depth_stretched = (depth_stretched * 255).astype(np.uint8)
|
|
|
38 |
|
39 |
return depth_pil
|
40 |
|
41 |
+
def generate_bas_relief_and_depth(prompt):
|
42 |
+
# Use the token "BAS-RELIEF" so the LoRA triggers
|
|
|
43 |
full_prompt = f"BAS-RELIEF {prompt}"
|
44 |
+
print("Generating image with LoRA style...")
|
|
|
45 |
result = pipe(
|
46 |
prompt=full_prompt,
|
47 |
+
num_inference_steps=15, # reduce if too slow
|
48 |
guidance_scale=7.5,
|
49 |
+
height=512, # reduce if you still get timeouts
|
50 |
width=512
|
51 |
)
|
52 |
+
image = result.images[0]
|
53 |
|
54 |
+
print("Running DPT Depth Estimation...")
|
55 |
+
inputs = feature_extractor(image, return_tensors="pt").to(device)
|
56 |
with torch.no_grad():
|
57 |
outputs = depth_model(**inputs)
|
58 |
predicted_depth = outputs.predicted_depth
|
59 |
|
|
|
60 |
prediction = torch.nn.functional.interpolate(
|
61 |
predicted_depth.unsqueeze(1),
|
62 |
+
size=image.size[::-1],
|
63 |
mode="bicubic",
|
64 |
+
align_corners=False
|
65 |
+
).squeeze()
|
|
|
|
|
|
|
66 |
|
67 |
+
depth_map_pil = enhance_depth_map(prediction.cpu().numpy())
|
68 |
|
69 |
+
return image, depth_map_pil
|
70 |
|
71 |
+
title = "Bas-Relief (SDXL + LoRA) + Depth Map (with PEFT)"
|
72 |
description = (
|
73 |
+
"Loads stable-diffusion-xl-base-1.0 on CPU, merges LoRA from 'KappaNeuro/bas-relief'. "
|
74 |
+
"Use 'BAS-RELIEF' token in your prompt to trigger the style, then compute a depth map."
|
75 |
)
|
76 |
|
77 |
iface = gr.Interface(
|
78 |
fn=generate_bas_relief_and_depth,
|
79 |
inputs=gr.Textbox(
|
80 |
+
label="Description",
|
81 |
+
placeholder="woman in shibari, marble relief, intricately carved"
|
82 |
),
|
83 |
outputs=[gr.Image(label="Bas-Relief Image"), gr.Image(label="Depth Map")],
|
84 |
title=title,
|