|
import pickle |
|
from typing import Union |
|
|
|
import numpy as np |
|
import opensr_test |
|
import torch |
|
from diffusers import LDMSuperResolutionPipeline |
|
|
|
|
|
def create_stable_diffusion_model(
    device: Union[str, torch.device] = "cuda"
) -> LDMSuperResolutionPipeline:
    """Load the pretrained LDM super-resolution pipeline and move it to a device.

    Args:
        device (Union[str, torch.device], optional): Target handed to the
            pipeline's ``to`` method. Defaults to "cuda".

    Returns:
        LDMSuperResolutionPipeline: The pipeline to use for
            super resolution, as returned by ``to(device)``.
    """
    # Fetch the published checkpoint and relocate it in one expression.
    return LDMSuperResolutionPipeline.from_pretrained(
        "CompVis/ldm-super-resolution-4x-openimages"
    ).to(device)
|
|
|
|
|
def _super_resolve(
    model: LDMSuperResolutionPipeline, image: torch.Tensor
) -> torch.Tensor:
    """Run the pipeline on one (C, H, W) image and return a float (C, H, W) tensor."""
    with torch.no_grad():
        # The pipeline is called on a batch of one, always with 100 steps and eta=1.
        output = model(image[None], num_inference_steps=100, eta=1)
    # The first returned image is treated as an (H, W, C) array with values
    # in 0..255: it is rescaled to [0, 1] and reordered to channel-first.
    return torch.from_numpy(np.array(output.images[0]) / 255).permute(2, 0, 1).float()


def run_diffuser(
    model: LDMSuperResolutionPipeline,
    lr: np.ndarray,
    hr: np.ndarray,
    device: Union[str, torch.device] = "cuda",
) -> dict:
    """Run the model on the low resolution image

    Bands 3, 2 and 1 of ``lr`` (in that order) are divided by 2000, clamped
    to [0, 1] and passed to ``model``. When the second axis of the selected
    bands has length 121, the image is reflect-padded by 3 pixels before and
    4 pixels after on its last two axes (121 + 3 + 4 = 128) for inference,
    and the padded border is cropped from the result again.

    Args:
        model (LDMSuperResolutionPipeline): The model to use
        lr (np.ndarray): The low resolution image. Must be a NumPy array
            (it is handed to ``torch.from_numpy``) whose first axis has
            at least 4 entries.
        hr (np.ndarray): The high resolution image. Must be a NumPy array;
            only its first three entries along axis 0 are kept.
        device (Union[str, torch.device], optional): The device
            to use. Defaults to "cuda".

    Returns:
        dict: The results of the model, with keys ``"lr"``, ``"hr"`` and
            ``"sr"``. Each value is a ``np.uint16`` array scaled by 2000.
    """
    # Padding applied around a 121-pixel image, and the factor used to turn
    # that padding into crop offsets on the output (assumes the model
    # upsamples by 4, as the crop arithmetic requires -- confirm for other
    # models).
    pad_before, pad_after = 3, 4
    scale = 4

    # NOTE(review): the tensor handed to the pipeline lies in [0, 1];
    # confirm this is the value range the pipeline expects for tensor inputs.
    lr = (torch.from_numpy(lr[[3, 2, 1]]) / 2000).to(device).clamp(0, 1)

    if lr.shape[1] == 121:
        # Pad a batched copy; squeeze(0) removes only the batch axis that was
        # added for the padding call.
        padded = torch.nn.functional.pad(
            lr[None],
            pad=(pad_before, pad_after, pad_before, pad_after),
            mode="reflect",
        ).squeeze(0)
        sr = _super_resolve(model, padded)
        # Drop the rows/columns of the output that correspond to the padding.
        # ``lr`` itself was never padded, so it needs no cropping.
        crop = slice(pad_before * scale, -pad_after * scale)
        sr = sr[:, crop, crop]
    else:
        sr = _super_resolve(model, lr)

    # Scale everything back by 2000; astype(np.uint16) truncates fractions.
    lr = (lr.cpu().numpy() * 2000).astype(np.uint16)
    hr = ((hr[0:3] / 2000).clip(0, 1) * 2000).astype(np.uint16)
    sr = (sr.cpu().numpy() * 2000).astype(np.uint16)

    return {"lr": lr, "hr": hr, "sr": sr}
|
|