Update app.py
app.py
CHANGED
import gradio as gr
import torch
import numpy as np

from diffusers import StableDiffusionXLPipeline
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from PIL import Image, ImageEnhance, ImageOps

############################################
# 1. Setup and Model Loading
############################################
device = "cpu"  # or "cuda" if a GPU is available
torch_dtype = torch.float32  # float32 for CPU; use torch.float16 on GPU

# --- Load Base SDXL Model ---
# (Large model; be sure you have enough memory, or use fewer steps)
print("Loading SDXL Base model...")
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch_dtype
)
pipe.to(device)
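# (Optional) On memory-constrained hosts, attention slicing can cut peak
# memory use at some speed cost: pipe.enable_attention_slicing()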

# --- Load LoRA Weights from KappaNeuro/bas-relief ---
# The safetensors file is named "BAS-RELIEF.safetensors".
# This merges the LoRA into the pipeline so you can trigger it via the "BAS-RELIEF" token.
print("Loading bas-relief LoRA weights...")
pipe.load_lora_weights(
    "KappaNeuro/bas-relief",
    weight_name="BAS-RELIEF.safetensors"
)
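# (Optional) pipe.fuse_lora() could be called here to bake the LoRA into the
# base weights, which can make repeated inference slightly faster.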

# --- Load Depth Estimation Model ---
# We'll use Intel's DPT for depth. It is also fairly large, so expect slow CPU inference.
print("Loading DPT Depth model...")
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)


############################################
# 2. Depth Map Enhancement (PIL-based)
############################################
def enhance_depth_map(depth_arr: np.ndarray) -> Image.Image:
    """
    - Normalize depth to [0, 255]
    - Auto-contrast to emphasize details
    - Sharpen edges
    """
    # Min-max stretch; the epsilon guards against division by zero on flat maps
    d_min, d_max = depth_arr.min(), depth_arr.max()
    depth_stretched = (depth_arr - d_min) / (d_max - d_min + 1e-8)
    depth_stretched = (depth_stretched * 255).astype(np.uint8)

    depth_pil = Image.fromarray(depth_stretched)

    # Auto-contrast
    depth_pil = ImageOps.autocontrast(depth_pil)

    # Sharpen (the factor is illustrative; tune to taste)
    depth_pil = ImageEnhance.Sharpness(depth_pil).enhance(2.0)

    return depth_pil
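# For example, enhance_depth_map(np.array([[0.0, 1.0], [2.0, 3.0]])) yields a
# 2x2 grayscale image whose pixel values are stretched across the 0..255 range.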


############################################
# 3. Generation + Depth Inference Function
############################################
def generate_bas_relief_and_depth(prompt: str):
    """
    1) Generate a 'bas-relief' style image using the LoRA from KappaNeuro/bas-relief.
       - The prompt must include the "BAS-RELIEF" token for the style to apply.
    2) Compute a depth map using Intel/dpt-large.
    3) Return (image, depth_map).
    """

    # -- Step A: Merge the user's prompt with the "BAS-RELIEF" instance token --
    # You can experiment with different prompt styles, e.g.
    # "BAS-RELIEF sculpture of a woman in shibari, marble, octane render..."
    full_prompt = f"BAS-RELIEF {prompt}"
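    # (Optional) a negative_prompt (e.g. "blurry, low quality") can also be
    # passed to pipe() below to steer the output away from common artifacts.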

    # -- Step B: Generate the image with SDXL + LoRA --
    # Keep the resolution modest to avoid timeouts on CPU
    print("Generating bas-relief image...")
    result = pipe(
        prompt=full_prompt,
        num_inference_steps=15,  # fewer steps => faster, but lower quality
        guidance_scale=7.5,
        height=512,  # can be reduced to e.g. 384 if still too slow
        width=512
    )

    # Extract the image from the pipeline result
    generated_image = result.images[0]

    # -- Step C: Depth Estimation with DPT --
    print("Running depth estimation...")
    inputs = feature_extractor(generated_image, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = depth_model(**inputs)
    predicted_depth = outputs.predicted_depth  # shape: [batch, height, width]

    # Resize to match the generated image's resolution
    # (PIL size is (width, height), so reverse it for torch's (height, width))
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=generated_image.size[::-1],
        mode="bicubic",
        align_corners=False,
    ).squeeze()  # drop batch/channel dims -> (height, width)

    depth_arr = prediction.cpu().numpy()
    depth_map_pil = enhance_depth_map(depth_arr)

    return generated_image, depth_map_pil


############################################
# 4. Gradio Interface
############################################
title = "Bas-Relief with SDXL + LoRA + Depth Map"
description = (
    "This demo loads SDXL-base on CPU (slow!) and merges the LoRA from KappaNeuro/bas-relief. "
    "Use 'BAS-RELIEF' in your prompt for the style; a depth map is then generated with DPT. "
    "Lower the resolution or the step count if you hit timeouts."
)

iface = gr.Interface(
    fn=generate_bas_relief_and_depth,
    inputs=gr.Textbox(
        label="Describe your scene/style",
        placeholder="sculpture of a woman in shibari, marble, intricate details"
    ),
    outputs=[
        gr.Image(label="Bas-Relief Image"),
        gr.Image(label="Depth Map"),
    ],
    title=title,
    description=description
)

if __name__ == "__main__":
    iface.launch()
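A minimal smoke test, assuming the models have downloaded (the prompt and file
names here are illustrative):

generated, depth = generate_bas_relief_and_depth("lion head, marble, intricate details")
generated.save("bas_relief.png")
depth.save("depth_map.png")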