File size: 3,210 Bytes
eb5b895 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import os
import sys
import numpy as np
import torch
import argparse
import logging
import torch.utils
from PIL import Image
from torch.autograd import Variable
from model import Finetunemodel
from multi_read_data import DataLoader
from thop import profile
# Parent of the *current working directory* (abspath resolves '../' against the
# CWD, not against this file) is appended to sys.path.
# NOTE(review): this runs after `from model import ...` and
# `from multi_read_data import ...` above, so it cannot influence those
# imports — confirm whether it is still needed.
root_dir = os.path.abspath('../')
sys.path.append(root_dir)

parser = argparse.ArgumentParser("ZERO-IG")
parser.add_argument('--data_path_test_low', type=str, default='./data',
                    help='location of the data corpus')
parser.add_argument('--save', type=str,
                    default='./results/',
                    help='directory for log.txt and the result/ output images')
parser.add_argument('--model_test', type=str,
                    default='./model',
                    help='location of the trained model weights loaded by Finetunemodel')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
# NOTE(review): --seed is not referenced anywhere else in this script.
parser.add_argument('--seed', type=int, default=2, help='random seed')
args = parser.parse_args()

save_path = args.save
os.makedirs(save_path, exist_ok=True)

# Log to stdout and mirror every record into <save>/log.txt.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
file_handler = logging.FileHandler(os.path.join(save_path, 'log.txt'))
file_handler.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(file_handler)
# Log just the script's file name (os.path.split would log a (dir, name) tuple).
logging.info("test file name = %s", os.path.basename(__file__))

# Test images are served one at a time, in a fixed order.
TestDataset = DataLoader(img_dir=args.data_path_test_low, task='test')
test_queue = torch.utils.data.DataLoader(
    TestDataset, batch_size=1, pin_memory=True, num_workers=0, shuffle=False)
def save_images(tensor):
    """Convert the first image of a batched tensor into a uint8 HWC array.

    Despite the name, nothing is written to disk here; the array is returned
    so the caller can hand it to ``PIL.Image.fromarray``.

    Args:
        tensor: batched image tensor; ``tensor[0]`` must be (C, H, W). Values
            are presumably in [0, 1] — they are scaled by 255 and clipped.

    Returns:
        ``np.ndarray`` of dtype uint8 with shape (H, W, C).
    """
    # detach() first: Tensor.numpy() raises on tensors that still require grad.
    image_numpy = tensor[0].detach().cpu().float().numpy()
    image_numpy = np.transpose(image_numpy, (1, 2, 0))
    # astype('uint8') truncates after clipping (e.g. 127.5 -> 127).
    return np.clip(image_numpy * 255.0, 0, 255.0).astype('uint8')
def calculate_model_parameters(model):
    """Count every scalar element held by ``model.parameters()``.

    Trainable and frozen parameters are both included; ``numel()`` of each
    parameter tensor is summed.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
def calculate_model_flops(model, input_tensor):
    """Profile one forward pass of ``model`` on ``input_tensor`` using thop.

    Returns:
        The first count reported by ``thop.profile`` divided by 1e9 (giga-units).
    """
    # NOTE(review): thop documents profile()'s first return value as MACs
    # (multiply-accumulates); a FLOP count is commonly ~2x that — confirm
    # before reporting this number as GFLOPs.
    op_count, _param_count = profile(model, inputs=(input_tensor,))
    giga = 1e9
    return op_count / giga
def main():
    """Run the fine-tuned model over the test set and save both of its outputs.

    Exits with status 1 when CUDA is unavailable. For every image yielded by
    ``test_queue`` the model returns an ``(enhance, output)`` pair, which is
    written as ``<args.save>/result/<stem>_enhance.png`` and
    ``<args.save>/result/<stem>_denoise.png`` respectively.
    """
    if not torch.cuda.is_available():
        print('no gpu device available')
        sys.exit(1)
    # Honour --gpu: without this the flag is parsed but ignored and .cuda()
    # always targets the default device.
    torch.cuda.set_device(args.gpu)

    model = Finetunemodel(args.model_test)
    model = model.cuda()
    model.eval()

    # Calculate model size
    total_params = calculate_model_parameters(model)
    print("Total number of parameters: ", total_params)
    for p in model.parameters():
        p.requires_grad = False

    # The output directory is the same for every image: create it once.
    result_dir = os.path.join(args.save, 'result')
    os.makedirs(result_dir, exist_ok=True)

    # no_grad() restores the previous grad mode on exit, so no trailing
    # torch.set_grad_enabled(True) is needed.
    with torch.no_grad():
        for input_tensor, img_name in test_queue:
            # Variable(..., volatile=True) has been a no-op (plus a warning)
            # since PyTorch 0.4; no_grad() already disables autograd here.
            input_tensor = input_tensor.cuda()
            # basename + splitext keeps inner dots ("a.b.png" -> "a.b"),
            # whereas split('.')[0] would truncate such names to "a".
            input_name = os.path.splitext(os.path.basename(img_name[0]))[0]
            enhance, output = model(input_tensor)
            enhance = save_images(enhance)
            output = save_images(output)
            Image.fromarray(output).save(
                os.path.join(result_dir, input_name + '_denoise.png'), 'PNG')
            Image.fromarray(enhance).save(
                os.path.join(result_dir, input_name + '_enhance.png'), 'PNG')


if __name__ == '__main__':
    main()
|