"""Upscale an 8x8 RGB image to 64x64 with the imgen3flip model.

Usage: script.py <input-filename> <output-filename>
"""
from imgen3flip import weights_path, Model, ImageBatch, OPTS
import torch
import torchvision as TV
import torchvision.transforms.functional as VF
import sys

# Validate inputs with explicit checks, not `assert`: asserts are stripped
# under `python -O`, and the original multi-line f-string inside an assert
# was only valid syntax on Python 3.12+ (PEP 701).
if len(sys.argv) != 3:
    sys.exit(f"Usage: {sys.argv[0]} <input-filename> <output-filename>")
input_filename = sys.argv[1]
output_filename = sys.argv[2]
if input_filename == output_filename:
    sys.exit("Use different file names")
if not weights_path.exists():
    sys.exit("Model weights do not exist")

print("Loading the model")
model = Model()
# A state_dict contains only tensors, so opt into the restricted (safe)
# unpickler; avoids arbitrary-code execution from a tampered checkpoint.
model.load_state_dict(torch.load(weights_path, weights_only=True))
# Switch dropout/batch-norm layers to inference behavior.
model.eval()

print(f"Loading 8x8 input image from {input_filename}")
# Read image and ditch alpha-channel if it is present.
image = TV.io.read_image(input_filename)[:3]
if image.shape[0] != 3:
    sys.exit("RGB image expected")
# Convert range from uint8 0..255 to float 0.0..1.0
image = image / 255.0
# Convert C H W -> H W C
image = image.permute(1, 2, 0)
# Model expects exactly an 8x8 RGB input.
if image.shape != (8, 8, 3):
    sys.exit(f"Expected an 8x8 RGB image, got {tuple(image.shape)}")
# Add batch dimension (B=1): H W C -> 1 H W C.
# NOTE: the original `.view(1, 8, 8, 3)` raises RuntimeError here because
# permute() returns a non-contiguous tensor; unsqueeze() works on any strides.
image = image.unsqueeze(0)

# Construct the batch the model consumes. Target and loss are placeholders:
# they are unused in inference, but the model code always calculates a loss.
dummy_target = torch.zeros(1, 64, 64, 3, **OPTS)
dummy_loss = torch.tensor(-1, **OPTS)
inference_batch = ImageBatch(
    im8=image.to(**OPTS),
    im64=dummy_target,
    loss=dummy_loss)
# Pure inference: skip building the autograd graph.
with torch.no_grad():
    result = model(inference_batch)

# Convert the output tensor to PIL format so we can save it.
new_image = result.im64.detach().float().cpu()
# new_image: 1 H W C -> H W C (drop the batch dimension)
new_image = new_image[0]
# new_image: H W C -> C H W (channel-first, as to_pil_image expects)
new_image = new_image.permute(2, 0, 1)
# to_pil_image interprets floats as [0, 1]; clamp to avoid wrap-around
# artifacts when the model output slightly over/undershoots that range.
new_image = new_image.clamp(0.0, 1.0)
if new_image.shape != (3, 64, 64):
    sys.exit(f"Unexpected model output shape: {tuple(new_image.shape)}")
img = VF.to_pil_image(new_image)

# Save the result.
print(f"Writing {img.height}x{img.width} image to {output_filename}")
img.save(output_filename)