File size: 3,894 Bytes
3d5e231
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b98e7a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# ------------------------------------------------------------------------------------
# Minimal DALL-E
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------

import hashlib
import os
import random
import tarfile
import urllib
import urllib.parse
import urllib.request
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm


def set_seed(seed: int):
    """Seed every random number generator this project relies on.

    Seeds, in order: Python's ``random`` module, NumPy's global RNG, PyTorch's
    CPU RNG, and PyTorch's RNG on all CUDA devices.

    Args:
        seed: The seed value passed to each generator.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


def _md5_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex MD5 digest of the file at ``path``, reading it in chunks."""
    digest = hashlib.md5()
    with open(path, 'rb') as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def download(url: str, root: str) -> str:
    """Download a ``.tar.gz`` archive from ``url`` into ``root`` and extract it there.

    The second-to-last path component of ``url`` is taken as the expected MD5
    hex digest of the archive. The returned path is ``root`` joined with the
    archive filename minus its last 7 characters (the ``.tar.gz`` suffix).

    If the archive already exists in ``root`` and the result path exists and
    is not a regular file, nothing is downloaded and the result path is
    returned immediately.

    Args:
        url: Archive URL of the form ``.../<md5>/<name>.tar.gz``.
        root: Directory receiving the archive and its extracted contents;
            created if missing.

    Returns:
        The path ``root/<name>``.

    Raises:
        RuntimeError: If the downloaded archive's MD5 digest does not match,
            or if an archive member would be extracted outside ``root``.
    """
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)
    pathname = filename[:-len('.tar.gz')]

    expected_md5 = url.split("/")[-2]
    download_target = os.path.join(root, filename)
    result_path = os.path.join(root, pathname)

    # Cache hit: archive present and already extracted.
    if os.path.isfile(download_target) and os.path.exists(result_path) and not os.path.isfile(result_path):
        return result_path

    with urllib.request.urlopen(url) as source, open(download_target, 'wb') as output:
        # Content-Length can be absent (e.g. chunked transfer); tqdm accepts total=None.
        content_length = source.info().get('Content-Length')
        total = int(content_length) if content_length is not None else None
        with tqdm(total=total, ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break
                output.write(buffer)
                loop.update(len(buffer))

    # Hash in chunks so large model archives are not loaded into memory at once.
    if _md5_of_file(download_target) != expected_md5:
        raise RuntimeError('Model has been downloaded but the md5 checksum does not match')

    root_real = os.path.realpath(root)
    with tarfile.open(download_target, 'r:gz') as f:
        members = f.getmembers()
        with tqdm(members, total=len(members)) as pbar:
            for member in pbar:
                pbar.set_description(f'extracting: {member.name} (size:{member.size // (1024 * 1024)}MB)')
                # The archive comes from an arbitrary URL: refuse members such as
                # "../x" or absolute paths that would be written outside ``root``.
                target = os.path.realpath(os.path.join(root_real, member.name))
                if os.path.commonpath([root_real, target]) != root_real:
                    raise RuntimeError(f'Refusing to extract {member.name!r} outside of {root!r}')
                f.extract(member=member, path=root)

    return result_path


def realpath_url_or_path(url_or_path: str, root: Optional[str] = None) -> str:
    """Resolve ``url_or_path`` to a local path, downloading it first if it is a URL.

    Args:
        url_or_path: A local path (returned unchanged) or an ``http``/``https``
            URL that is passed to ``download``.
        root: Directory to download and extract into; required for URLs.

    Returns:
        ``url_or_path`` itself, or the path returned by ``download``.

    Raises:
        ValueError: If ``url_or_path`` is an ``http``/``https`` URL but ``root`` is None.
    """
    if urllib.parse.urlparse(url_or_path).scheme not in ('http', 'https'):
        return url_or_path
    if root is None:
        # download() would otherwise fail with an opaque TypeError in os.makedirs(None).
        raise ValueError(f'a root directory is required to download {url_or_path!r}')
    return download(url_or_path, root)


def images_to_numpy(tensor):
    """Convert a 3-D image tensor with values in [-1, 1] to a ``uint8`` NumPy array.

    The axes are permuted with ``transpose(1, 2, 0)`` (CHW -> HWC). Values are
    clipped to [-1, 1] and mapped linearly to [0, 255]; the ``uint8`` cast
    truncates fractional values. The input tensor is left unmodified.
    """
    array = tensor.detach().cpu().numpy().transpose(1, 2, 0)
    # np.clip returns a new array. The previous in-place masked assignment wrote
    # through to the caller's tensor, because .numpy() shares memory with a CPU tensor.
    array = np.clip(array, -1, 1)
    return ((array + 1) / 2 * 255).astype('uint8')


def save_image(ground_truth, images, out_dir, batch_idx):
    """Write every image in ``images`` to ``out_dir`` as PNG files.

    Each element ``im`` of ``images`` is dispatched on its rank:

    * rank 2 or 3: a single image, saved as ``test_<batch_idx>_<i>.png``;
    * any other rank: a batch along axis 0, with each ``im[j]`` saved as
      ``test_<batch_idx>_<i>_<j>.png``.

    Args:
        ground_truth: Currently unused; kept for call-site compatibility.
        images: Iterable of arrays accepted by ``matplotlib.pyplot.imsave``.
        out_dir: Output directory; created if it does not exist. An empty
            string writes into the current working directory.
        batch_idx: Identifier embedded in every output filename.
    """
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    for i, im in enumerate(images):
        # Rank 2 (grayscale) is a single image too; treating it as a batch would
        # hand 1-D rows to imsave, which rejects them.
        if len(im.shape) in (2, 3):
            plt.imsave(os.path.join(out_dir, f'test_{batch_idx}_{i}.png'), im)
        else:
            for j in range(im.shape[0]):
                plt.imsave(os.path.join(out_dir, f'test_{batch_idx}_{i}_{j}.png'), im[j])