from PIL import Image
from torchvision.transforms import PILToTensor

# Load the demo image. PIL opens files lazily and holds the file handle, so use
# the context manager to close it once the pixel data has been read.
with Image.open('images/demo.png') as ip:
    # PILToTensor returns a (C, H, W) tensor without rescaling values
    # (per torchvision docs), unlike ToTensor.
    lr = PILToTensor()(ip)

# Add a batch axis and move channels last: (C, H, W) -> (1, H, W, C).
lr = lr.unsqueeze(0).permute(0, 2, 3, 1)
print(lr.shape)

# Drop the batch axis and convert to NumPy: (1, H, W, C) -> (H, W, C).
# NOTE: a former `.permute(0, 1, 2)` here was an identity permutation (no-op)
# and has been removed; the resulting array is unchanged.
lr = lr.squeeze(0).numpy()
print(lr.shape)