File size: 3,171 Bytes
44189a1 c71b96e 44189a1 dc78df8 44189a1 c71b96e 44189a1 dc78df8 c71b96e 44189a1 dc78df8 44189a1 dc78df8 c71b96e dc78df8 c71b96e dc78df8 c71b96e 44189a1 c71b96e 44189a1 8c25de0 73b0806 8c25de0 73b0806 c71b96e 73b0806 8c25de0 73b0806 c71b96e 8c25de0 44189a1 8c25de0 44189a1 c71b96e 44189a1 c71b96e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import logging
import torch
import numpy as np
from PIL import Image
from diffusers.utils import check_min_version
from pipeline import LotusGPipeline, LotusDPipeline
from utils.image_utils import colorize_depth_map
from contextlib import nullcontext
# Require diffusers >= 0.28.0.dev0, the minimum this script declares for the
# Lotus pipelines. Presumably check_min_version raises if the installed
# version is older; confirm against the diffusers docs.
check_min_version('0.28.0.dev0')
def load_models(task_name, device, *, dtype=torch.float16):
    """Load the G and D Lotus pipelines for ``task_name`` onto ``device``.

    (G/D presumably denote Lotus's generative and deterministic variants.)

    Args:
        task_name: ``'depth'`` selects the depth checkpoints. Any other value
            selects the surface-normal checkpoints.
        device: Target passed to each pipeline's ``.to``.
        dtype: Passed as ``torch_dtype`` to ``from_pretrained``. Defaults to
            ``torch.float16``, the value previously hard-coded. Pass
            ``torch.float32`` where half precision may be poorly supported
            (e.g. CPU).

    Returns:
        ``(pipe_g, pipe_d)``: both pipelines moved to ``device``, with
        progress bars disabled.
    """
    if task_name == 'depth':
        model_g = 'jingheya/lotus-depth-g-v1-0'
        model_d = 'jingheya/lotus-depth-d-v1-1'
    else:
        model_g = 'jingheya/lotus-normal-g-v1-0'
        model_d = 'jingheya/lotus-normal-d-v1-0'

    pipe_g = LotusGPipeline.from_pretrained(model_g, torch_dtype=dtype)
    pipe_d = LotusDPipeline.from_pretrained(model_d, torch_dtype=dtype)

    for pipe in (pipe_g, pipe_d):
        pipe.to(device)
        # Silence per-step progress bars during batch inference.
        pipe.set_progress_bar_config(disable=True)

    logging.info("Successfully loaded pipelines from %s and %s.", model_g, model_d)
    return pipe_g, pipe_d
def _images_to_tensor(images_batch, device, dtype):
    """Stack RGB images into one (B, 3, H, W) tensor rescaled from [0, 255] to [-1, 1].

    All images must share the same height and width, because ``torch.stack``
    requires equal shapes.
    """
    arrays = [np.array(img.convert('RGB'), dtype=np.float32) for img in images_batch]
    # HWC -> CHW for each image, then stack along a new batch axis.
    batch = torch.stack([torch.from_numpy(arr).permute(2, 0, 1) for arr in arrays])
    batch = batch / 127.5 - 1.0
    return batch.to(device=device, dtype=dtype)


def _task_embedding(batch_size, device, dtype):
    """Build the (batch_size, 4) task embedding ``[sin 1, sin 0, cos 1, cos 0]`` per row."""
    emb = torch.tensor([1, 0], device=device, dtype=dtype).unsqueeze(0)
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
    return emb.repeat(batch_size, 1)


def _postprocess(preds, task_name):
    """Turn raw numpy predictions into displayable outputs, one per prediction."""
    if task_name == 'depth':
        # Average over the last axis (presumably the 3 channels of an (H, W, 3)
        # array) to get a single-channel map before colorizing.
        return [colorize_depth_map(p.mean(axis=-1)) for p in preds]
    # Clip so out-of-range values cannot wrap around in the uint8 cast.
    return [
        Image.fromarray((np.clip(p, 0.0, 1.0) * 255).astype(np.uint8))
        for p in preds
    ]


def infer_pipe(pipe, images_batch, task_name, seed, device):
    """Run single-step (timestep 999) Lotus inference on a batch of images.

    Args:
        pipe: Pipeline called as ``pipe(rgb_in=..., task_emb=..., ...)``. Its
            result must expose ``.images``. ``pipe.device.type`` selects the
            autocast device. ``pipe.dtype``, if present, sets the input
            precision; otherwise float16 is used.
        images_batch: Iterable of images exposing ``.convert('RGB')`` (e.g.
            PIL images). All must have the same size.
        task_name: ``'depth'`` returns colorized depth maps. Any other value
            returns normal maps as PIL images.
        seed: Integer seed for a ``torch.Generator`` on ``device``, or ``None``
            for an unseeded run.
        device: Device for the generator and the input tensors.

    Returns:
        A list with one post-processed output per prediction. An empty batch
        returns ``[]`` instead of failing inside ``torch.stack``.
    """
    images_batch = list(images_batch)
    if not images_batch:
        return []

    generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)

    # Autocast is skipped whenever MPS is available on this machine;
    # otherwise it follows the pipeline's device type.
    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast(pipe.device.type)

    # Match the pipeline's weight precision (float16 when loaded as such).
    dtype = getattr(pipe, 'dtype', torch.float16)

    with torch.no_grad(), autocast_ctx:
        rgb_in = _images_to_tensor(images_batch, device, dtype)
        task_emb = _task_embedding(rgb_in.shape[0], device, dtype)
        preds = pipe(
            rgb_in=rgb_in,
            prompt='',
            num_inference_steps=1,
            generator=generator,
            output_type='np',
            timesteps=[999],
            task_emb=task_emb,
        ).images

    return _postprocess(preds, task_name)
def lotus(images_batch, task_name, seed, device, pipe_g, pipe_d):
    """Run Lotus inference on a batch using only the D pipeline.

    Args:
        images_batch: Images forwarded unchanged to ``infer_pipe``.
        task_name: Forwarded to ``infer_pipe``. It selects depth vs. normal
            post-processing, so outputs are depth maps only when it is
            ``'depth'``.
        seed: Forwarded to ``infer_pipe`` (``None`` for an unseeded run).
        device: Forwarded to ``infer_pipe``.
        pipe_g: G pipeline. Accepted for interface compatibility; currently
            unused.
        pipe_d: D pipeline that actually runs.

    Returns:
        The list returned by ``infer_pipe`` for ``pipe_d``.
    """
    # Only the D pipeline is run; its outputs are returned for whichever task
    # ``task_name`` selects (not necessarily depth).
    return infer_pipe(pipe_d, images_batch, task_name, seed, device)