"""
Split images into small patches and insert them into sqlite db.  Reading and Inserting speeds are much better than
Ubuntu's (18.04) file system when the number of patches is larger than 20k. And it has smaller size than using h5 format

Recommend to check or filter out small size patches as their content vary little. 128x128 seems better than 64x64.


"""
import sqlite3
from torch.utils.data import DataLoader
from tqdm import trange
from Dataloader import Image2Sqlite

conn = sqlite3.connect("dataset/image_yandere.db")
cursor = conn.cursor()

# Trade durability for bulk-insert speed; an OS crash or power loss mid-write can
# corrupt the database, so this is only meant for the build step.
with conn:
    cursor.execute("PRAGMA SYNCHRONOUS = OFF")

table_name = "train_images_size_128_noise_1_rgb"
lr_col = "lr_img"
hr_col = "hr_img"

with conn:
    conn.execute(
        f"CREATE TABLE IF NOT EXISTS {table_name} ({lr_col} BLOB, {hr_col} BLOB)"
    )

dat = Image2Sqlite(
    img_folder="./dataset/yande.re_test_shrink",
    patch_size=256,
    shrink_size=2,
    noise_level=1,
    down_sample_method=None,
    color_mod="RGB",
    dummy_len=None,
)
print(f"Total images {len(dat)}")

img_dat = DataLoader(dat, num_workers=6, batch_size=6, shuffle=True)

num_batches = 20
for _ in trange(num_batches):
    bulk = []
    for lrs, hrs in img_dat:
        patches = list(zip(lrs, hrs))
        # patches = [(lr, hr) for lr, hr in zip(lrs, hrs) if len(lr) > 14000]

        bulk.extend(patches)

    # For 128x128 patches, 14,000 bytes is a fair cutoff; around 20% of patches are filtered out.
    bulk = [pair for pair in bulk if len(pair[0]) > 15000]
    cursor.executemany(
        f"INSERT INTO {table_name}({lr_col}, {hr_col}) VALUES (?,?)", bulk
    )
    conn.commit()

cursor.execute(f"select max(rowid) from {table_name}")
print(cursor.fetchall())
conn.commit()
# +++++++++++++++++++++++++++++++++++++
#           Used to Create a Test Database
# -------------------------------------

# cursor.execute(f"SELECT ROWID FROM {table_name} ORDER BY LENGTH({lr_col}) DESC LIMIT 400")
# rowdis = cursor.fetchall()
# rowdis = ",".join([str(i[0]) for i in rowdis])
#
# cursor.execute(f"DELETE FROM {table_name} WHERE ROWID NOT IN ({rowdis})")
# conn.commit()
# cursor.execute("vacuum")
#
# cursor.execute("""
# CREATE TABLE IF NOT EXISTS train_images_size_128_noise_1_rgb_small AS
# SELECT *
# FROM train_images_size_128_noise_1_rgb
# WHERE length(lr_img) < 14000;
# """)
#
# cursor.execute("""
# DELETE
# FROM train_images_size_128_noise_1_rgb
# WHERE length(lr_img) < 14000;
# """)

# Rebuild the database file: reclaims space freed by the deletions above and renumbers
# rowids for tables without an explicit INTEGER PRIMARY KEY.
cursor.execute("VACUUM")
conn.commit()

# +++++++++++++++++++++++++++++++++++++
#           Check Image Size
# -------------------------------------

from PIL import Image
import io

cursor.execute(
    f"""
    SELECT {hr_col} FROM {table_name}
    ORDER BY LENGTH({hr_col}) DESC
    LIMIT 100
"""
)
# Optional filter: WHERE LENGTH({lr_col}) BETWEEN 14000 AND 16000

# small = cursor.fetchall()
# print(len(small))
# Decode the stored blobs and write them out for visual inspection.
for idx, row in enumerate(cursor):
    img = Image.open(io.BytesIO(row[0]))
    img.save(f"dataset/check/{idx}.png")

# +++++++++++++++++++++++++++++++++++++
#           Check Image Variance
# -------------------------------------

import pandas as pd
import matplotlib.pyplot as plt

# Compressed blob length is a rough proxy for patch content variance, so the
# histogram helps pick the filtering threshold used during insertion above.
dat = pd.read_sql(f"SELECT LENGTH({lr_col}) FROM {table_name}", conn)
dat.hist(bins=20)
plt.show()
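
# +++++++++++++++++++++++++++++++++++++
#           Reading Patches Back (sketch)
# -------------------------------------
# A minimal sketch of consuming this database at training time, assuming the blobs
# are encoded image bytes (as the PIL check above suggests). It uses plain sqlite3 and
# PIL; the class name and rowid-based indexing are illustrative, not this repo's API.

import io
import sqlite3

from PIL import Image
from torch.utils.data import Dataset


class SqlitePatchDataset(Dataset):
    """Hypothetical reader: one (lr, hr) PIL image pair per rowid in the patch table."""

    def __init__(self, db_path, table, lr_col="lr_img", hr_col="hr_img"):
        # Each DataLoader worker should construct its own instance (and connection);
        # sharing one sqlite3 connection across forked workers is unsafe.
        self.conn = sqlite3.connect(db_path)
        self.table, self.lr_col, self.hr_col = table, lr_col, hr_col
        self.length = self.conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # After VACUUM, rowids of tables without an explicit INTEGER PRIMARY KEY are
        # typically renumbered 1..N, so idx + 1 addresses a row directly.
        lr_blob, hr_blob = self.conn.execute(
            f"SELECT {self.lr_col}, {self.hr_col} FROM {self.table} WHERE rowid = ?",
            (idx + 1,),
        ).fetchone()
        lr = Image.open(io.BytesIO(lr_blob)).convert("RGB")
        hr = Image.open(io.BytesIO(hr_blob)).convert("RGB")
        return lr, hr


# Example usage (kept commented out so the build script above stays the entry point):
# reader = SqlitePatchDataset("dataset/image_yandere.db", table_name)
# lr, hr = reader[0]
# print(len(reader), lr.size, hr.size)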