csaybar commited on
Commit
747091e
1 Parent(s): 8d3effd

Upload 2 files

Browse files
Files changed (2) hide show
  1. diffuser/run.py +30 -0
  2. diffuser/utils.py +81 -0
diffuser/run.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import create_stable_diffusion_model, run_diffuser
2
+ import opensr_test
3
+ import matplotlib.pyplot as plt
4
+
5
+ # Load the model
6
+ model = create_stable_diffusion_model(device="cuda")
7
+
8
+ # Load the dataset
9
+ dataset = opensr_test.load("naip")
10
+ lr_dataset, hr_dataset = dataset["L2A"], dataset["HRharm"]
11
+
12
+ # Run the model
13
+ results = run_diffuser(
14
+ model=model,
15
+ lr=lr_dataset[5][:,0:64, 0:64],
16
+ hr=hr_dataset[5][:,0:256, 0:256],
17
+ device="cuda"
18
+ )
19
+
20
+ # Display the results
21
+ fig, ax = plt.subplots(1, 3, figsize=(10, 5))
22
+ ax[0].imshow(results["lr"].transpose(1, 2, 0)/3000)
23
+ ax[0].set_title("LR")
24
+ ax[0].axis("off")
25
+ ax[1].imshow(results["sr"].transpose(1, 2, 0)/3000)
26
+ ax[1].set_title("SR")
27
+ ax[1].axis("off")
28
+ ax[2].imshow(results["hr"].transpose(1, 2, 0) / 3000)
29
+ ax[2].set_title("HR")
30
+ plt.show()
diffuser/utils.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import LDMSuperResolutionPipeline
2
+ import numpy as np
3
+ import opensr_test
4
+ import torch
5
+ import pickle
6
+ from typing import Union
7
+
8
+
9
+ def create_stable_diffusion_model(
10
+ device: Union[str, torch.device] = "cuda"
11
+ ) -> LDMSuperResolutionPipeline:
12
+ """ Create the stable diffusion model
13
+
14
+ Returns:
15
+ LDMSuperResolutionPipeline: The model to use for
16
+ super resolution.
17
+ """
18
+ model_id = "CompVis/ldm-super-resolution-4x-openimages"
19
+ pipeline = LDMSuperResolutionPipeline.from_pretrained(model_id)
20
+ pipeline = pipeline.to(device)
21
+ return pipeline
22
+
23
+ def run_diffuser(
24
+ model: LDMSuperResolutionPipeline,
25
+ lr: torch.Tensor,
26
+ hr: torch.Tensor,
27
+ device: Union[str, torch.device] = "cuda"
28
+ ) -> dict:
29
+ """ Run the model on the low resolution image
30
+
31
+ Args:
32
+ model (LDMSuperResolutionPipeline): The model to use
33
+ lr (torch.Tensor): The low resolution image
34
+ hr (torch.Tensor): The high resolution image
35
+ device (Union[str, torch.device], optional): The device
36
+ to use. Defaults to "cuda".
37
+
38
+ Returns:
39
+ dict: The results of the model
40
+ """
41
+
42
+ # move the images to the device
43
+ lr = (torch.from_numpy(lr[[3, 2, 1]]) / 2000).to(device).clamp(0, 1)
44
+
45
+ if lr.shape[1] == 121:
46
+ # add padding
47
+ lr = torch.nn.functional.pad(
48
+ lr[None],
49
+ pad=(3, 4, 3, 4),
50
+ mode='reflect'
51
+ ).squeeze()
52
+
53
+ # run the model
54
+ with torch.no_grad():
55
+ sr = model(lr[None], num_inference_steps=100, eta=1)
56
+ sr = torch.from_numpy(
57
+ np.array(sr.images[0])/255
58
+ ).permute(2,0,1).float()
59
+
60
+ # remove padding
61
+ sr = sr[:, 3*4:-4*4, 3*4:-4*4]
62
+ lr = lr[:, 3:-4, 3:-4]
63
+ else:
64
+ # run the model
65
+ with torch.no_grad():
66
+ sr = model(lr[None], num_inference_steps=100, eta=1)
67
+ sr = torch.from_numpy(
68
+ np.array(sr.images[0])/255
69
+ ).permute(2,0,1).float()
70
+
71
+ lr = (lr.cpu().numpy() * 2000).astype(np.uint16)
72
+ hr = ((hr[0:3] / 2000).clip(0, 1) * 2000).astype(np.uint16)
73
+ sr = (sr.cpu().numpy() * 2000).astype(np.uint16)
74
+
75
+ results = {
76
+ "lr": lr,
77
+ "hr": hr,
78
+ "sr": sr
79
+ }
80
+
81
+ return results