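"""Upscale an 8x8 RGB image to a 64x64 one using a pretrained imgen3flip model.

Usage: python <this-script> <input-filename> <output-filename>
"""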
from imgen3flip import weights_path, Model, ImageBatch, OPTS
import torch
import torchvision as TV
import torchvision.transforms.functional as VF
import sys


assert weights_path.exists(), "Model weights do not exist"

assert len(sys.argv) == 3, \
    f"Usage: {sys.argv[0]} <input-filename> <output-filename>"

input_filename = sys.argv[1]
output_filename = sys.argv[2]

assert input_filename != output_filename, "Input and output file names must differ"

print("Loading the model")
model = Model()
model.load_state_dict(torch.load(weights_path))
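# Put the model in eval mode so any dropout/batch-norm layers behave
# deterministically at inference time (assumes Model is an nn.Module,
# which the load_state_dict call above already implies)
model.eval()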

print(f"Loading 8x8 input image from {input_filename}")
# Read the image and drop the alpha channel if present
image = TV.io.read_image(input_filename)[:3]
# Convert range from 0..255 to 0.0..1.0
image = image / 255.0
assert image.shape[0] == 3, "RGB image expected"
# Convert C H W -> H W C
image = image.permute(1, 2, 0)
# Add a batch dimension (B=1): H W C -> 1 H W C.
# H, W and C are spelled out explicitly because the model expects 8x8x3 input.
# reshape() is used instead of view() because permute() above made the tensor
# non-contiguous, and view() would raise an error on it.
image = image.reshape(1, 8, 8, 3)

# Construct the batch structure the model consumes.
# Target and loss are ignored at inference time, but the model code always
# computes a loss, so dummy values are supplied.
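# (OPTS is assumed to be a dict of shared tensor kwargs, e.g. dtype and device.)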
dummy_target = torch.zeros(1, 64, 64, 3, **OPTS)
dummy_loss = torch.tensor(-1, **OPTS)
inference_batch = ImageBatch(
    im8=image.to(**OPTS),
    im64=dummy_target,
    loss=dummy_loss)
# No gradients are needed for inference, so skip building the autograd graph
with torch.no_grad():
    result = model(inference_batch)

# Convert the result to a PIL image so it can be saved
new_image = result.im64.detach().float().cpu()
# new_image: 1 H W C -> H W C
new_image = new_image[0]
# new_image: H W C -> C H W
new_image = new_image.permute(2, 0, 1)
assert new_image.shape == (3, 64, 64)
# Clamp to [0, 1]: to_pil_image scales float tensors by 255, and out-of-range
# values would overflow when cast to uint8
img = VF.to_pil_image(new_image.clamp(0.0, 1.0))
# Save
print(f"Writing {img.height}x{img.width} image to {output_filename}")
img.save(output_filename)