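# NOTE: hosting-page residue captured together with the file (not part of the program):
# Stable-X | Update code | 9dfa4de | raw | history blame | No virus | 10.7 kB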
# Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------
import functools
import os
import gradio as gr
import numpy as np
import torch as torch
from PIL import Image
import spaces
import diffusers
from stablenormal.pipeline_yoso_normal import YOSONormalsPipeline
from stablenormal.pipeline_stablenormal import StableNormalPipeline
from stablenormal.scheduler.heuristics_ddimsampler import HEURI_DDIMScheduler
from data_utils import HWC3, resize_image
import sys
import cv2
sys.path.append('./geowizard')
from models.geowizard_pipeline import DepthNormalEstimationPipeline
class Geowizard(object):
    '''
    Thin wrapper around the GeoWizard depth/normal estimation pipeline.

    The housekeeping methods (``cuda``, ``cpu``, ``float``, ``to``, ``eval``,
    ``train``) forward to the wrapped pipeline and return ``self`` so that
    calls can be chained.
    '''

    def __init__(self):
        # Half-precision weights are loaded from a path relative to the
        # current working directory.
        self.model = DepthNormalEstimationPipeline.from_pretrained("weights/Geowizard/", torch_dtype=torch.float16)

    def cuda(self):
        self.model.cuda()
        return self

    def cpu(self):
        self.model.cpu()
        return self

    def float(self):
        self.model.float()
        return self

    def to(self, device):
        self.model.to(device)
        return self

    def eval(self):
        self.model.eval()
        return self

    def train(self):
        self.model.train()
        return self

    @staticmethod
    def _normals_to_uint8(normals):
        '''
        Rescale a normal map with ``(n + 1) / 2 * 255`` and cast it to uint8.

        The result is clipped to [0, 255] before the cast so that components
        slightly outside [-1, 1] saturate instead of wrapping around.
        '''
        return np.clip((normals + 1) / 2 * 255, 0, 255).astype(np.uint8)

    @torch.no_grad()
    def __call__(self, img, image_resolution=768):
        '''
        Predict a surface-normal visualisation for a single image.

        Args:
            img: Image array in OpenCV BGR channel order; it is converted to
                RGB before being handed to the pipeline.
            image_resolution: Forwarded to the pipeline as ``processing_res``.

        Returns:
            ``np.uint8`` array built from the pipeline's ``normal_np`` output,
            rescaled so that [-1, 1] maps onto [0, 255].
        '''
        rgb_image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        pipe_out = self.model(
            rgb_image,
            denoising_steps=10,
            ensemble_size=1,
            processing_res=image_resolution,
            match_input_res=True,
            domain="indoor",
            color_map="Spectral",
            show_progress_bar=False,
        )
        return self._normals_to_uint8(pipe_out.normal_np)

    def __repr__(self):
        return f"model: \n{self.model}"
class Marigold(Geowizard):
    '''
    Wrapper around the diffusers Marigold normals pipeline.

    The device/precision helpers and ``__repr__`` are inherited unchanged
    from ``Geowizard``.
    '''

    def __init__(self):
        # The parent constructor is intentionally not called: it would load
        # the GeoWizard weights instead of the Marigold ones.
        self.model = diffusers.MarigoldNormalsPipeline.from_pretrained("weights/marigold-normals-v0-1", torch_dtype=torch.float16)

    @torch.no_grad()
    def __call__(self, img, image_resolution=768):
        '''
        Predict a surface-normal visualisation for a single image.

        Args:
            img: Image array in OpenCV BGR channel order; it is converted to
                RGB before being handed to the pipeline.
            image_resolution: Forwarded to the pipeline as
                ``processing_resolution`` (previously this argument was
                accepted but ignored).

        Returns:
            ``np.uint8`` array built from the first prediction, rescaled so
            that [-1, 1] maps onto [0, 255] and clipped to that range.
        '''
        rgb_image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        pipe_out = self.model(rgb_image, processing_resolution=image_resolution)
        pred_normal = pipe_out.prediction[0]
        # Flip the sign of the first component.
        # NOTE(review): presumably aligns Marigold's axis convention with the
        # other estimators in this comparison - confirm.
        pred_normal[..., 0] = -pred_normal[..., 0]
        pred_normal = (pred_normal + 1) / 2 * 255
        # Clip before the cast so out-of-range values saturate rather than wrap.
        return np.clip(pred_normal, 0, 255).astype(np.uint8)
class StableNormal(Geowizard):
    '''
    Wrapper around the two-stage StableNormal pipeline.

    A YOSO normals pipeline is attached to the main StableNormal pipeline as
    its ``x_start_pipeline``. ``cuda``/``cpu``/``float``/``eval``/``train``
    and ``__repr__`` are inherited from ``Geowizard``.
    '''

    def __init__(self):
        # NOTE(review): both weight locations are absolute paths, unlike the
        # relative "weights/..." paths used by the other wrappers - confirm
        # they exist on the deployment machine.
        x_start_pipeline = YOSONormalsPipeline.from_pretrained(
            '/workspace/code/InverseRendering/StableNormal/weights/yoso-normal-v0-2',
            variant="fp16", torch_dtype=torch.float16)
        self.model = StableNormalPipeline.from_pretrained(
            '/workspace/code/InverseRendering/StableNormal/weights/stable-normal-v0-1',
            variant="fp16", torch_dtype=torch.float16,
            scheduler=HEURI_DDIMScheduler(prediction_type='sample',
                                          beta_start=0.00085, beta_end=0.0120,
                                          beta_schedule="scaled_linear"))
        # two stage concat
        self.model.x_start_pipeline = x_start_pipeline
        # NOTE(review): these two sub-modules are placed on 'cuda'
        # unconditionally, so construction requires a CUDA device.
        self.model.x_start_pipeline.to('cuda', torch.float16)
        self.model.prior.to('cuda', torch.float16)

    @torch.no_grad()
    def __call__(self, img, image_resolution=768):
        '''
        Predict a surface-normal visualisation for a single image.

        Args:
            img: Image array in OpenCV BGR channel order; it is converted to
                RGB before being handed to the pipeline.
            image_resolution: Unused here; kept so that every estimator
                shares the same call signature.

        Returns:
            ``np.uint8`` array built from the first prediction, rescaled so
            that [-1, 1] maps onto [0, 255] and clipped to that range.
        '''
        rgb_image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        pipe_out = self.model(rgb_image)
        pred_normal = (pipe_out.prediction[0] + 1) / 2 * 255
        # Clip before the cast so out-of-range values saturate rather than wrap.
        return np.clip(pred_normal, 0, 255).astype(np.uint8)

    def to(self, device):
        '''Move the pipeline to ``device`` in float16 and return ``self``.'''
        self.model.to(device, torch.float16)
        # Return self like the base-class ``to`` so calls can be chained.
        return self

    def __repr__(self):
        return f"model: \n{self.model}"
class DSINE(object):
    '''
    Wrapper around the DSINE normal estimator loaded through ``torch.hub``.

    The housekeeping methods forward to the wrapped model and return
    ``self`` so that calls can be chained.
    '''

    def __init__(self):
        self.model = torch.hub.load("hugoycj/DSINE-hub", "DSINE", local_file_path='./models/dsine.pt', trust_repo=True)

    def cuda(self):
        self.model.cuda()
        return self

    def cpu(self):
        # Provided for parity with the other estimator wrappers in this file.
        self.model.cpu()
        return self

    def float(self):
        self.model.float()
        return self

    def to(self, device):
        self.model.to(device)
        return self

    def eval(self):
        self.model.eval()
        return self

    def train(self):
        self.model.train()
        return self

    @torch.no_grad()
    def __call__(self, img, image_resolution=768):
        '''
        Predict a surface-normal visualisation for a single image.

        Args:
            img: Image passed straight to the model's ``infer_cv2``.
            image_resolution: Unused here; kept so that every estimator
                shares the same call signature.

        Returns:
            ``np.uint8`` array with the channel axis moved last, rescaled so
            that [-1, 1] maps onto [0, 255] and clipped to that range.
        '''
        pred_normal = self.model.infer_cv2(img)[0]  # (3, H, W)
        pred_normal = (pred_normal + 1) / 2 * 255
        pred_normal = pred_normal.cpu().numpy().transpose(1, 2, 0)
        # rgb
        # Clip before the cast so out-of-range values saturate rather than wrap.
        return np.clip(pred_normal, 0, 255).astype(np.uint8)

    def __repr__(self):
        return f"model: \n{self.model}"
def _move_pipe(pipe, device):
    '''Best-effort device move: a failure is reported but never aborts the run.'''
    import warnings

    try:
        pipe.to(device)
    except Exception as exc:  # deliberately broad, but no longer swallows KeyboardInterrupt/SystemExit
        warnings.warn(
            f"could not move {type(pipe).__name__} to {device!r}: {exc}",
            RuntimeWarning,
        )


def process(
    pipe_list,
    path_input,
):
    '''
    Run every estimator on one image and yield the results progressively.

    Args:
        pipe_list: Estimator wrappers; each is called as ``pipe(img, 768)``
            and is moved to 'cuda' before and back to 'cpu' after its run.
        path_input: Path of the image to read with ``cv2.imread``.

    Yields:
        After each estimator, a list of four entries: the PIL images
        produced so far followed by ``None`` for the slots still pending.

    Raises:
        gr.Error: If no path was given or the image could not be read.
    '''
    num_output_slots = 4  # the yielded list is always padded to four entries

    # Read and resize once: the input is identical for every estimator.
    img = cv2.imread(path_input) if path_input else None
    if img is None:
        raise gr.Error(f"Could not read input image: {path_input!r}")
    raw_input_image = HWC3(img)
    ori_H, ori_W, _ = raw_input_image.shape
    resized = resize_image(raw_input_image, 768)

    path_out_vis_list = []
    for pipe in pipe_list:
        _move_pipe(pipe, 'cuda')
        try:
            # Each estimator gets its own copy so none can disturb the shared input.
            pipe_out = pipe(
                resized.copy(),
                768,
            )
        finally:
            # Move the model off the GPU even when inference fails.
            _move_pipe(pipe, 'cpu')

        # Bring the prediction back to the original image size.
        pred_normal = cv2.resize(pipe_out, (ori_W, ori_H))
        path_out_vis_list.append(Image.fromarray(pred_normal))

        yield path_out_vis_list + [None] * (num_output_slots - len(path_out_vis_list))
def run_demo_server(pipe):
    '''
    Build the Gradio comparison UI and launch it on 0.0.0.0:7860.

    Args:
        pipe: List of estimator wrappers, bound as the first argument of
            ``process``. Results are shown in the order DSINE, Marigold,
            Geowizard, StableNormal, so the list should follow that order.
    '''
    process_pipe = spaces.GPU(functools.partial(process, pipe), duration=120)
    os.environ["GRADIO_ALLOW_FLAGGING"] = "never"

    with gr.Blocks(
        analytics_enabled=False,
        title="Normal Estimation Comparison",
        css="""
#download {
height: 118px;
}
.slider .inner {
width: 5px;
background: #FFF;
}
.viewport {
aspect-ratio: 4/3;
}
h1 {
text-align: center;
display: block;
}
h2 {
text-align: center;
display: block;
}
h3 {
text-align: center;
display: block;
}
""",
    ) as demo:
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Input Image",
                    type="filepath",
                    height=256,
                )
            with gr.Column():
                submit_btn = gr.Button(value="Compute normal", variant="primary")
                clear_btn = gr.Button(value="Clear")

        with gr.Row():
            with gr.Column():
                DSINE_output_slider = gr.Image(
                    label="DSINE",
                    type="filepath",
                )
            with gr.Column():
                marigold_output_slider = gr.Image(
                    label="Marigold",
                    type="filepath",
                )
            with gr.Column():
                geowizard_output_slider = gr.Image(
                    label="Geowizard",
                    type="filepath",
                )
            with gr.Column():
                Ours_slider = gr.Image(
                    label="StableNormal",
                    type="filepath",
                )

        # Order must match the order in which `process` fills its result list.
        outputs = [
            DSINE_output_slider,
            marigold_output_slider,
            geowizard_output_slider,
            Ours_slider,
        ]

        submit_btn.click(
            fn=process_pipe,
            inputs=input_image,
            outputs=outputs,
            concurrency_limit=1,
        )

        gr.Examples(
            fn=process_pipe,
            examples=sorted([
                os.path.join("files", "images", name)
                for name in os.listdir(os.path.join("files", "images"))
            ]),
            inputs=input_image,
            outputs=outputs,
            cache_examples=False,
        )

        def clear_fn():
            '''Reset the UI: one return value per component in the clear outputs list, in order.'''
            return [
                gr.Button(interactive=True),             # submit_btn
                gr.Image(value=None, interactive=True),  # input_image
                None,                                    # marigold_output_slider
                None,                                    # geowizard_output_slider
                None,                                    # DSINE_output_slider
                None,                                    # Ours_slider
            ]

        clear_btn.click(
            fn=clear_fn,
            inputs=[],
            outputs=[
                submit_btn,
                input_image,
                marigold_output_slider,
                geowizard_output_slider,
                DSINE_output_slider,
                Ours_slider,
            ],
        )

    demo.queue(
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )
def main():
    '''Instantiate the four estimators and start the comparison demo.'''
    marigold_pipe = Marigold()
    geowizard_pipe = Geowizard()
    dsine_pipe = DSINE()
    our_pipe = StableNormal()

    # The list order determines which UI slot each estimator's result lands in:
    # DSINE, Marigold, Geowizard, StableNormal.
    run_demo_server([dsine_pipe, marigold_pipe, geowizard_pipe, our_pipe])


if __name__ == "__main__":
    main()