Commit fa0f216 · 1 Parent(s): 434bf7c
vittoriopippi committed

Initial commit

This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
.gitignore ADDED
@@ -0,0 +1,3 @@
+ taylor_swift.png
+ test.py
+ *.pyc
Groundtruth/gan.iam.test.gt.filter27 ADDED
The diff for this file is too large to render. See raw diff
 
Groundtruth/gan.iam.tr_va.gt.filter27 ADDED
The diff for this file is too large to render. See raw diff
 
README.md CHANGED
@@ -1,3 +1,99 @@
- ---
- license: mit
- ---
+ # Handwritten Text Generation from Visual Archetypes ++
+
+ This repository contains the code for training the VATr++ Styled Handwritten Text Generation model.
+
+ ## Installation
+
+ ```bash
+ conda create --name vatr python=3.9
+ conda activate vatr
+ conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia
+ git clone https://github.com/aimagelab/VATr.git && cd VATr
+ pip install -r requirements.txt
+ ```
+
+ [This folder](https://drive.google.com/drive/folders/13rJhjl7VsyiXlPTBvnp1EKkKEhckLalr?usp=sharing) contains the regular IAM dataset `IAM-32.pickle` and the modified version with attached punctuation marks, `IAM-32-pa.pickle`.
+ The folder also contains the synthetically pretrained weights for the encoder, `resnet_18_pretrained.pth`.
+ Please download these files and place them in the `files` folder.
+
+ ## Training
+
+ To train the regular VATr model, use the following command. It uses the default settings from the paper.
+
+ ```bash
+ python train.py
+ ```
+
+ Useful arguments:
+ ```bash
+ python train.py
+         --feat_model_path PATH   # path to the pretrained ResNet-18 checkpoint; defaults to the synthetically pretrained model
+         --is_cycle               # use the style cycle loss for training
+         --dataset DATASET        # dataset to use (default: IAM)
+         --resume                 # resume training from the last checkpoint with the same name
+         --wandb                  # use wandb for logging
+ ```
+
+ Use the following arguments to apply the full VATr++ training setup:
+ ```bash
+ python train.py
+         --d-crop-size 64 128          # randomly crop the discriminator input to a width between 64 and 128
+         --text-augment-strength 0.4   # text augmentation for adding more rare characters
+         --file-suffix pa              # use the punctuation-attached version of IAM
+         --augment-ocr                 # augment the real images used to train the OCR model
+ ```
+
+ ### Pretraining dataset
+ The model `resnet_18_pretrained.pth` was pretrained on the [Font Square](https://github.com/aimagelab/font_square) dataset.
+
+ ## Generate Styled Handwritten Text Images
+
+ We provide utilities to generate handwritten text images with the trained model. They are used as follows:
+
+ ```bash
+ python generate.py [ACTION] --checkpoint files/vatrpp.pth
+ ```
+
+ The following actions are available, with their respective arguments.
+
+ ### Custom Author
+
+ Generate the given text for a custom author.
+
+ ```bash
+ text    --text STRING         # string to generate
+         --text-path PATH      # optional path to a text file
+         --output PATH         # optional output location, default: files/output.png
+         --style-folder PATH   # optional style folder containing writer samples, default: files/style_samples/00
+ ```
+ Style samples for the author are needed. These can be generated automatically from an image of a page using `create_style_sample.py`.
+ ```bash
+ python create_style_sample.py   --input-image PATH      # path of the image to extract the style samples from
+                                 --output-folder PATH    # folder where the style samples should be saved
+ ```
+
+ ### All Authors
+
+ Generate some text for all authors of IAM. The output is saved to `saved_images/author_samples/`.
+
+ ```bash
+ authors --test-set        # generate for the authors of the test set; otherwise the training set is used
+         --checkpoint PATH # checkpoint used to generate text, files/vatr.pth by default
+         --align           # detect the bottom line of each word and align the words
+         --at-once         # generate the whole sentence at once instead of word by word
+         --output-style    # also save the style images used to generate the words
+ ```
+
+ ### Evaluation Images
+
+ ```bash
+ fid     --target_dataset_path PATH   # dataset file for which the test set will be generated
+         --dataset-path PATH          # dataset file from which style samples are taken, for example the punctuation-attached version
+         --output PATH                # where to save the images, default: saved_images/fid
+         --checkpoint PATH            # checkpoint used to generate text, files/vatr.pth by default
+         --all-epochs                 # generate evaluation images for all saved epochs (the checkpoint has to be a folder)
+         --fake-only                  # only output fake images, no ground truth
+         --test-only                  # only generate the test set, not the train set
+         --long-tail                  # only generate words containing long-tail characters
+ ```
config.json ADDED
@@ -0,0 +1,46 @@
+ {
+   "add_noise": false,
+   "alphabet": "Only thewigsofrcvdampbkuq.A-210xT5'MDL,RYHJ\"ISPWENj&BC93VGFKz();#:!7U64Q8?+*ZX/%",
+   "architectures": [
+     "VATrPP"
+   ],
+   "augment_ocr": false,
+   "batch_size": 8,
+   "corpus": "standard",
+   "d_crop_size": null,
+   "d_lr": 1e-05,
+   "dataset": "IAM",
+   "device": "cuda",
+   "english_words_path": "files/english_words.txt",
+   "epochs": 100000,
+   "feat_model_path": "files/resnet_18_pretrained.pth",
+   "file_suffix": null,
+   "g_lr": 5e-05,
+   "img_height": 32,
+   "is_cycle": false,
+   "label_encoder": "default",
+   "model_type": "emuru",
+   "no_ocr_loss": false,
+   "no_writer_loss": false,
+   "num_examples": 15,
+   "num_words": 3,
+   "num_workers": 0,
+   "num_writers": 339,
+   "ocr_lr": 5e-05,
+   "query_input": "unifont",
+   "resolution": 16,
+   "save_model": 5,
+   "save_model_history": 500,
+   "save_model_path": "saved_models",
+   "seed": 742,
+   "special_alphabet": "\u0391\u03b1\u0392\u03b2\u0393\u03b3\u0394\u03b4\u0395\u03b5\u0396\u03b6\u0397\u03b7\u0398\u03b8\u0399\u03b9\u039a\u03ba\u039b\u03bb\u039c\u03bc\u039d\u03bd\u039e\u03be\u039f\u03bf\u03a0\u03c0\u03a1\u03c1\u03a3\u03c3\u03c2\u03a4\u03c4\u03a5\u03c5\u03a6\u03c6\u03a7\u03c7\u03a8\u03c8\u03a9\u03c9",
+   "tag": "debug",
+   "text_aug_type": "proportional",
+   "text_augment_strength": 0.0,
+   "torch_dtype": "float32",
+   "transformers_version": "4.46.2",
+   "vocab_size": 80,
+   "w_lr": 5e-05,
+   "wandb": false,
+   "writer_loss_weight": 1.0
+ }
configuration_vatrpp.py ADDED
@@ -0,0 +1,82 @@
+ from transformers import PretrainedConfig
+
+ class VATrPPConfig(PretrainedConfig):
+     model_type = "emuru"
+
+     def __init__(self,
+                  feat_model_path='files/resnet_18_pretrained.pth',
+                  label_encoder='default',
+                  save_model_path='saved_models',
+                  dataset='IAM',
+                  english_words_path='files/english_words.txt',
+                  wandb=False,
+                  no_writer_loss=False,
+                  writer_loss_weight=1.0,
+                  no_ocr_loss=False,
+                  img_height=32,
+                  resolution=16,
+                  batch_size=8,
+                  num_examples=15,
+                  num_writers=339,
+                  alphabet='Only thewigsofrcvdampbkuq.A-210xT5\'MDL,RYHJ"ISPWENj&BC93VGFKz();#:!7U64Q8?+*ZX/%',
+                  special_alphabet='ΑαΒβΓγΔδΕεΖζΗηΘθΙιΚκΛλΜμΝνΞξΟοΠπΡρΣσςΤτΥυΦφΧχΨψΩω',
+                  g_lr=0.00005,
+                  d_lr=0.00001,
+                  w_lr=0.00005,
+                  ocr_lr=0.00005,
+                  epochs=100000,
+                  num_workers=0,
+                  seed=742,
+                  num_words=3,
+                  is_cycle=False,
+                  add_noise=False,
+                  save_model=5,
+                  save_model_history=500,
+                  tag='debug',
+                  device='cuda',
+                  query_input='unifont',
+                  corpus="standard",
+                  text_augment_strength=0.0,
+                  text_aug_type="proportional",
+                  file_suffix=None,
+                  augment_ocr=False,
+                  d_crop_size=None,
+                  **kwargs):
+         super().__init__(**kwargs)
+         self.feat_model_path = feat_model_path
+         self.label_encoder = label_encoder
+         self.save_model_path = save_model_path
+         self.dataset = dataset
+         self.english_words_path = english_words_path
+         self.wandb = wandb
+         self.no_writer_loss = no_writer_loss
+         self.writer_loss_weight = writer_loss_weight
+         self.no_ocr_loss = no_ocr_loss
+         self.img_height = img_height
+         self.resolution = resolution
+         self.batch_size = batch_size
+         self.num_examples = num_examples
+         self.num_writers = num_writers
+         self.alphabet = alphabet
+         self.special_alphabet = special_alphabet
+         self.g_lr = g_lr
+         self.d_lr = d_lr
+         self.w_lr = w_lr
+         self.ocr_lr = ocr_lr
+         self.epochs = epochs
+         self.num_workers = num_workers
+         self.seed = seed
+         self.num_words = num_words
+         self.is_cycle = is_cycle
+         self.add_noise = add_noise
+         self.save_model = save_model
+         self.save_model_history = save_model_history
+         self.tag = tag
+         self.device = device
+         self.query_input = query_input
+         self.corpus = corpus
+         self.text_augment_strength = text_augment_strength
+         self.text_aug_type = text_aug_type
+         self.file_suffix = file_suffix
+         self.augment_ocr = augment_ocr
+         self.d_crop_size = d_crop_size
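Editorial note (not part of the committed file): since `VATrPPConfig` subclasses `PretrainedConfig`, it can be serialized and reloaded with the standard `transformers` config API. The sketch below assumes only that interface and is consistent with the `config.json` added in this commit; the output directory name is illustrative.

```python
# Minimal sketch: round-trip the model configuration with the PretrainedConfig API.
from configuration_vatrpp import VATrPPConfig

config = VATrPPConfig(batch_size=16, is_cycle=True)    # override a couple of defaults
config.save_pretrained("vatrpp-config")                # writes config.json to the directory
reloaded = VATrPPConfig.from_pretrained("vatrpp-config")
assert reloaded.batch_size == 16 and reloaded.model_type == "emuru"
```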
corpora_english/brown-azAZ.tr ADDED
The diff for this file is too large to render. See raw diff
 
corpora_english/in_vocab.subset.tro.37 ADDED
@@ -0,0 +1,114 @@
1
+ accents
2
+ fifty
3
+ gross
4
+ Tea
5
+ whom
6
+ renamed
7
+ Heaven
8
+ Harry
9
+ arrange
10
+ captain
11
+ why
12
+ Father
13
+ beaten
14
+ Bar
15
+ base
16
+ creamy
17
+ About
18
+ Allies
19
+ sound
20
+ farmers
21
+ anyone
22
+ steel
23
+ Mary
24
+ used
25
+ fever
26
+ looking
27
+ lately
28
+ returns
29
+ humans
30
+ finals
31
+ beyond
32
+ lots
33
+ waiting
34
+ cited
35
+ measure
36
+ posse
37
+ blow
38
+ blonde
39
+ twice
40
+ Having
41
+ compels
42
+ rooms
43
+ cocked
44
+ virtual
45
+ dying
46
+ tons
47
+ Travel
48
+ idea
49
+ gripped
50
+ Act
51
+ reign
52
+ moods
53
+ altered
54
+ sample
55
+ Soviet
56
+ thick
57
+ enigma
58
+ here
59
+ egghead
60
+ Public
61
+ Bryan
62
+ porous
63
+ estate
64
+ guilty
65
+ Caught
66
+ Lucas
67
+ observe
68
+ mouth
69
+ pricked
70
+ obscure
71
+ casual
72
+ take
73
+ home
74
+ amber
75
+ weekend
76
+ forming
77
+ aid
78
+ outlook
79
+ uniting
80
+ But
81
+ earnest
82
+ bear
83
+ news
84
+ sparked
85
+ merrily
86
+ extreme
87
+ North
88
+ damned
89
+ big
90
+ bosses
91
+ context
92
+ easily
93
+ took
94
+ hurried
95
+ Gene
96
+ due
97
+ deserve
98
+ cult
99
+ leisure
100
+ critics
101
+ parish
102
+ Music
103
+ charge
104
+ grey
105
+ Privy
106
+ Fred
107
+ massive
108
+ others
109
+ shirt
110
+ average
111
+ warning
112
+ Tuesday
113
+ locked
114
+ possess
corpora_english/oov.common_words ADDED
@@ -0,0 +1,79 @@
1
+ planets
2
+ lips
3
+ varies
4
+ impact
5
+ skips
6
+ Gold
7
+ maple
8
+ voyager
9
+ noisy
10
+ stick
11
+ forums
12
+ drafts
13
+ crimson
14
+ sever
15
+ rackets
16
+ sexy
17
+ humming
18
+ cheated
19
+ lick
20
+ grades
21
+ heroic
22
+ Clever
23
+ foul
24
+ mood
25
+ warrior
26
+ Morning
27
+ poetic
28
+ nodding
29
+ certify
30
+ reviews
31
+ mosaics
32
+ senders
33
+ Isle
34
+ Lied
35
+ sand
36
+ Weight
37
+ writer
38
+ trusts
39
+ slot
40
+ eaten
41
+ squares
42
+ lists
43
+ vary
44
+ witches
45
+ compose
46
+ demons
47
+ therapy
48
+ focus
49
+ sticks
50
+ Whose
51
+ bumped
52
+ visibly
53
+ redeem
54
+ arsenal
55
+ lunatic
56
+ Similar
57
+ Bug
58
+ adheres
59
+ trail
60
+ robbing
61
+ Whisky
62
+ super
63
+ screwed
64
+ Flower
65
+ salads
66
+ Glow
67
+ Vapor
68
+ Married
69
+ recieve
70
+ handle
71
+ push
72
+ card
73
+ skiing
74
+ lotus
75
+ cloud
76
+ windy
77
+ monkey
78
+ virus
79
+ thunder
corpora_english/oov_words.txt ADDED
@@ -0,0 +1,400 @@
1
+ planets
2
+ lips
3
+ varies
4
+ impact
5
+ skips
6
+ Gold
7
+ maple
8
+ voyager
9
+ noisy
10
+ stick
11
+ forums
12
+ drafts
13
+ crimson
14
+ sever
15
+ rackets
16
+ sexy
17
+ humming
18
+ cheated
19
+ lick
20
+ grades
21
+ heroic
22
+ Clever
23
+ foul
24
+ mood
25
+ warrior
26
+ Morning
27
+ poetic
28
+ nodding
29
+ certify
30
+ reviews
31
+ mosaics
32
+ senders
33
+ Isle
34
+ Lied
35
+ sand
36
+ Weight
37
+ writer
38
+ trusts
39
+ slot
40
+ eaten
41
+ squares
42
+ lists
43
+ vary
44
+ witches
45
+ compose
46
+ demons
47
+ therapy
48
+ focus
49
+ sticks
50
+ Whose
51
+ bumped
52
+ visibly
53
+ redeem
54
+ arsenal
55
+ lunatic
56
+ Similar
57
+ Bug
58
+ adheres
59
+ trail
60
+ robbing
61
+ Whisky
62
+ super
63
+ screwed
64
+ Flower
65
+ salads
66
+ Glow
67
+ Vapor
68
+ Married
69
+ recieve
70
+ handle
71
+ push
72
+ card
73
+ skiing
74
+ lotus
75
+ cloud
76
+ windy
77
+ monkey
78
+ virus
79
+ thunder
80
+ Keegan
81
+ purling
82
+ Orpheus
83
+ Prence
84
+ Yin
85
+ Kansas
86
+ jowls
87
+ Alabama
88
+ Szold
89
+ Chou
90
+ Orange
91
+ suspend
92
+ barred
93
+ deceit
94
+ reward
95
+ soy
96
+ Vail
97
+ lad
98
+ Loesser
99
+ Hutton
100
+ jerks
101
+ yelling
102
+ Heywood
103
+ sacker
104
+ comest
105
+ tense
106
+ par
107
+ fiend
108
+ Soiree
109
+ voted
110
+ Putting
111
+ pansy
112
+ doormen
113
+ mayor
114
+ Owens
115
+ noting
116
+ pauses
117
+ USP
118
+ crudely
119
+ grooved
120
+ furor
121
+ ignited
122
+ kittens
123
+ broader
124
+ slang
125
+ ballets
126
+ quacked
127
+ Paulus
128
+ Castles
129
+ upswing
130
+ dabbled
131
+ Animals
132
+ Kidder
133
+ Writers
134
+ laces
135
+ bled
136
+ scoped
137
+ yield
138
+ scoured
139
+ Schenk
140
+ Wratten
141
+ Menfolk
142
+ foamy
143
+ scratch
144
+ minced
145
+ nudged
146
+ Seats
147
+ Judging
148
+ Turbine
149
+ Strict
150
+ whined
151
+ crupper
152
+ Dussa
153
+ finned
154
+ voter
155
+ Jacobs
156
+ calmly
157
+ hip
158
+ clubs
159
+ quintet
160
+ blunts
161
+ Grazie
162
+ Barton
163
+ NAB
164
+ specie
165
+ Fonta
166
+ narrow
167
+ Swan
168
+ denials
169
+ Rawson
170
+ potato
171
+ Choral
172
+ diverse
173
+ Educate
174
+ unities
175
+ Ferry
176
+ Bonner
177
+ manuals
178
+ NAIR
179
+ imputed
180
+ initial
181
+ wallet
182
+ Sesame
183
+ maroon
184
+ Related
185
+ Quiney
186
+ Monster
187
+ brainy
188
+ Nolan
189
+ Thrifty
190
+ Tel
191
+ Ye
192
+ Sumter
193
+ Bonnet
194
+ sheepe
195
+ nagged
196
+ ribbing
197
+ hunt
198
+ AA
199
+ Pohly
200
+ triol
201
+ saws
202
+ popped
203
+ aloof
204
+ Ceramic
205
+ thong
206
+ typed
207
+ broadly
208
+ Figures
209
+ riddle
210
+ Otis
211
+ Sainted
212
+ upbeat
213
+ Getting
214
+ hisself
215
+ junta
216
+ Labans
217
+ starter
218
+ coward
219
+ Anthea
220
+ hurlers
221
+ Dervish
222
+ Turin
223
+ oud
224
+ tyranny
225
+ Rotary
226
+ Veneto
227
+ pulls
228
+ bowl
229
+ utopias
230
+ auburn
231
+ osmotic
232
+ myrtle
233
+ furrow
234
+ laws
235
+ Uh
236
+ Hodges
237
+ Wilde
238
+ Neck
239
+ snaked
240
+ decorum
241
+ edema
242
+ Dunston
243
+ clinics
244
+ Abide
245
+ Dover
246
+ voltaic
247
+ Modern
248
+ Farr
249
+ thaw
250
+ moi
251
+ leaning
252
+ wedlock
253
+ Carson
254
+ star
255
+ Hymn
256
+ Stack
257
+ genes
258
+ Shayne
259
+ Moune
260
+ slipped
261
+ legatee
262
+ coerced
263
+ Gates
264
+ pulse
265
+ Granny
266
+ bat
267
+ Fruit
268
+ Cadesi
269
+ Tee
270
+ Dreiser
271
+ Getz
272
+ Ways
273
+ cogs
274
+ hydrous
275
+ sweep
276
+ quarrel
277
+ mobcaps
278
+ slash
279
+ throats
280
+ Royaux
281
+ cafes
282
+ crusher
283
+ rusted
284
+ Eskimo
285
+ slatted
286
+ pallet
287
+ yelps
288
+ slanted
289
+ confide
290
+ Gomez
291
+ untidy
292
+ Sigmund
293
+ Marine
294
+ roll
295
+ NRL
296
+ Dukes
297
+ tumours
298
+ LP
299
+ turtles
300
+ audible
301
+ Woodrow
302
+ retreat
303
+ Orders
304
+ Conlow
305
+ hobby
306
+ skin
307
+ tally
308
+ frosted
309
+ drowned
310
+ wedged
311
+ queen
312
+ poised
313
+ eluded
314
+ Letter
315
+ ticking
316
+ kill
317
+ rancor
318
+ Plant
319
+ Brandel
320
+ Willows
321
+ riddles
322
+ carven
323
+ Spiller
324
+ yen
325
+ jerky
326
+ tenure
327
+ daubed
328
+ Serves
329
+ pimpled
330
+ ACTH
331
+ ruh
332
+ afield
333
+ suffuse
334
+ muffins
335
+ Miners
336
+ Cabrini
337
+ weakly
338
+ upriver
339
+ Newsom
340
+ Meeker
341
+ weed
342
+ fiscal
343
+ Diane
344
+ Errors
345
+ Mig
346
+ biz
347
+ Drink
348
+ chop
349
+ Bumbry
350
+ Babin
351
+ optimum
352
+ Leyden
353
+ enrage
354
+ induces
355
+ newel
356
+ trim
357
+ bolts
358
+ frog
359
+ cinder
360
+ Lo
361
+ clobber
362
+ Mennen
363
+ Othon
364
+ Ocean
365
+ jerking
366
+ engine
367
+ Belasco
368
+ hero
369
+ flora
370
+ Injuns
371
+ Rico
372
+ Gary
373
+ snake
374
+ hating
375
+ Suggs
376
+ booze
377
+ Lescaut
378
+ Molard
379
+ startle
380
+ Aggie
381
+ lengthy
382
+ Shoals
383
+ ideals
384
+ Zen
385
+ stem
386
+ noon
387
+ hoes
388
+ Seafood
389
+ yuh
390
+ Mostly
391
+ seeds
392
+ bestow
393
+ acetate
394
+ jokers
395
+ waning
396
+ volumes
397
+ ein
398
+ Rich
399
+ Galt
400
+ pasted
create_style_sample.py ADDED
@@ -0,0 +1,25 @@
+ import os
+ import argparse
+
+ import cv2
+ from util.vision import get_page, get_words
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument("--input-image", type=str, required=True)
+     parser.add_argument("--output-folder", type=str, required=True, default='files/style_samples/00')
+
+     args = parser.parse_args()
+
+     image = cv2.imread(args.input_image)
+     image = cv2.resize(image, (image.shape[1], image.shape[0]))
+     result = get_page(image)
+     words, _ = get_words(result)
+
+     output_path = args.output_folder
+     if not os.path.exists(output_path):
+         os.mkdir(output_path)
+     for i, word in enumerate(words):
+         cv2.imwrite(os.path.join(output_path, f"word{i}.png"), word)
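Editorial note (not part of the committed file): the word crops written by `create_style_sample.py` are meant to be used as writer style samples. A minimal sketch of consuming them via `FolderDataset` (defined in `data/dataset.py` later in this commit), assuming the folder contains at least `num_examples` crops; the folder path is illustrative.

```python
# Minimal sketch: load extracted word crops as a style sample batch.
from data.dataset import FolderDataset

style_ds = FolderDataset("files/style_samples/00", num_examples=15)
style = style_ds.sample_style()
print(style["simg"].shape, style["swids"])   # 15 padded style images and their widths
```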
data/create_data.py ADDED
@@ -0,0 +1,469 @@
1
+ import gzip
2
+ import json
3
+ import os
4
+ import pickle
5
+ import random
6
+ from collections import defaultdict
7
+
8
+ import PIL
9
+ import cv2
10
+ import numpy as np
11
+ from PIL import Image
12
+
13
+
14
+ TO_MERGE = {
15
+ '.': 'left',
16
+ ',': 'left',
17
+ '!': 'left',
18
+ '?': 'left',
19
+ '(': 'right',
20
+ ')': 'left',
21
+ '\"': 'random',
22
+ "\'": 'random',
23
+ ":": 'left',
24
+ ";": 'left',
25
+ "-": 'random'
26
+ }
27
+
28
+ FILTER_ERR = False
29
+
30
+
31
+ def resize(image, size):
32
+ image_pil = Image.fromarray(image.astype('uint8'), 'L')
33
+ image_pil = image_pil.resize(size)
34
+ return np.array(image_pil)
35
+
36
+
37
+ def get_author_ids(base_folder: str):
38
+ with open(os.path.join(base_folder, "gan.iam.tr_va.gt.filter27"), 'r') as f:
39
+ training_authors = [line.split(",")[0] for line in f]
40
+ training_authors = set(training_authors)
41
+
42
+ with open(os.path.join(base_folder, "gan.iam.test.gt.filter27"), 'r') as f:
43
+ test_authors = [line.split(",")[0] for line in f]
44
+ test_authors = set(test_authors)
45
+
46
+ assert len(training_authors.intersection(test_authors)) == 0
47
+
48
+ return training_authors, test_authors
49
+
50
+
51
+ class IAMImage:
52
+ def __init__(self, image: np.array, label: str, image_id: int, line_id: str, bbox: list = None, iam_image_id: str = None):
53
+ self.image = image
54
+ self.label = label
55
+ self.image_id = image_id
56
+ self.line_id = line_id
57
+ self.iam_image_id = iam_image_id
58
+ self.has_bbox = False
59
+ if bbox is not None:
60
+ self.has_bbox = True
61
+ self.x, self.y, self.w, self.h = bbox
62
+
63
+ def merge(self, other: 'IAMImage'):
64
+ global MERGER_COUNT
65
+ assert self.has_bbox, "IAM image has no bounding box information"
66
+ y = min(self.y, other.y)
67
+ h = max(other.y + other.h, self.y + self.h) - y
68
+
69
+ x = min(self.x, other.x)
70
+ w = max(self.x + self.w, other.x + other.w) - x
71
+
72
+ new_image = np.ones((h, w), dtype=self.image.dtype) * 255
73
+
74
+ anchor_x = self.x - x
75
+ anchor_y = self.y - y
76
+ new_image[anchor_y:anchor_y + self.h, anchor_x:anchor_x + self.w] = self.image
77
+
78
+ anchor_x = other.x - x
79
+ anchor_y = other.y - y
80
+ new_image[anchor_y:anchor_y + other.h, anchor_x:anchor_x + other.w] = other.image
81
+
82
+ if other.x - (self.x + self.w) > 50:
83
+ new_label = self.label + " " + other.label
84
+ else:
85
+ new_label = self.label + other.label
86
+ new_id = self.image_id
87
+ new_bbox = [x, y, w, h]
88
+
89
+ new_iam_image_id = self.iam_image_id if len(self.label) > len(other.label) else other.iam_image_id
90
+ return IAMImage(new_image, new_label, new_id, self.line_id, new_bbox, iam_image_id=new_iam_image_id)
91
+
92
+
93
+ def read_iam_lines(base_folder: str) -> dict:
94
+ form_to_author = {}
95
+ with open(os.path.join(base_folder, "forms.txt"), 'r') as f:
96
+ for line in f:
97
+ if not line.startswith("#"):
98
+ form, author, *_ = line.split(" ")
99
+ form_to_author[form] = author
100
+
101
+ training_authors, test_authors = get_author_ids(base_folder)
102
+
103
+ dataset_dict = {
104
+ 'train': defaultdict(list),
105
+ 'test': defaultdict(list),
106
+ 'other': defaultdict(list)
107
+ }
108
+
109
+ image_count = 0
110
+
111
+ with open(os.path.join(base_folder, "sentences.txt"), 'r') as f:
112
+ for line in f:
113
+ if not line.startswith("#"):
114
+ line_id, _, ok, *_, label = line.rstrip().split(" ")
115
+ form_id = "-".join(line_id.split("-")[:2])
116
+ author_id = form_to_author[form_id]
117
+
118
+ if ok != 'ok' and FILTER_ERR:
119
+ continue
120
+
121
+ line_label = ""
122
+ for word in label.split("|"):
123
+ if not(len(line_label) == 0 or word in [".", ","]):
124
+ line_label += " "
125
+ line_label += word
126
+
127
+ image_path = os.path.join(base_folder, "sentences", form_id.split("-")[0], form_id, f"{line_id}.png")
128
+
129
+ subset = 'other'
130
+ if author_id in training_authors:
131
+ subset = 'train'
132
+ elif author_id in test_authors:
133
+ subset = 'test'
134
+
135
+ im = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
136
+ if im is not None and im.size > 1:
137
+ dataset_dict[subset][author_id].append(IAMImage(
138
+ im, line_label, image_count, line_id, None
139
+ ))
140
+ image_count += 1
141
+
142
+ return dataset_dict
143
+
144
+
145
+ def read_iam(base_folder: str) -> dict:
146
+ with open(os.path.join(base_folder, "forms.txt"), 'r') as f:
147
+ forms = [line.rstrip() for line in f if not line.startswith("#")]
148
+
149
+ training_authors, test_authors = get_author_ids(base_folder)
150
+
151
+ image_info = {}
152
+ with open(os.path.join(base_folder, "words.txt"), 'r') as f:
153
+ for line in f:
154
+ if not line.startswith("#"):
155
+ image_id, ok, threshold, x, y, w, h, tag, *content = line.rstrip().split(" ")
156
+ image_info[image_id] = {
157
+ 'ok': ok == 'ok',
158
+ 'threshold': threshold,
159
+ 'content': " ".join(content) if isinstance(content, list) else content,
160
+ 'bbox': [int(x), int(y), int(w), int(h)]
161
+ }
162
+
163
+ dataset_dict = {
164
+ 'train': defaultdict(list),
165
+ 'test': defaultdict(list),
166
+ 'other': defaultdict(list)
167
+ }
168
+
169
+ image_count = 0
170
+ err_count = 0
171
+
172
+ for form in forms:
173
+ form_id, writer_id, *_ = form.split(" ")
174
+ base_form = form_id.split("-")[0]
175
+
176
+ form_path = os.path.join(base_folder, "words", base_form, form_id)
177
+
178
+ for image_name in os.listdir(form_path):
179
+ image_id = image_name.split(".")[0]
180
+ info = image_info[image_id]
181
+
182
+ subset = 'other'
183
+ if writer_id in training_authors:
184
+ subset = 'train'
185
+ elif writer_id in test_authors:
186
+ subset = 'test'
187
+
188
+ if info['ok'] or not FILTER_ERR:
189
+ im = cv2.imread(os.path.join(form_path, image_name), cv2.IMREAD_GRAYSCALE)
190
+ if not info['ok'] and False:
191
+ cv2.destroyAllWindows()
192
+ print(info['content'])
193
+ cv2.imshow("image", im)
194
+ cv2.waitKey(0)
195
+
196
+ if im is not None and im.size > 1:
197
+ dataset_dict[subset][writer_id].append(IAMImage(
198
+ im, info['content'], image_count, "-".join(image_id.split("-")[:3]), info['bbox'], iam_image_id=image_id
199
+ ))
200
+ image_count += 1
201
+ else:
202
+ err_count += 1
203
+ print(f"Could not read image {image_name}, skipping")
204
+ else:
205
+ err_count += 1
206
+
207
+ assert not dataset_dict['train'].keys() & dataset_dict['test'].keys(), "Training and Testing set have common authors"
208
+
209
+ print(f"Skipped images: {err_count}")
210
+
211
+ return dataset_dict
212
+
213
+
214
+ def read_cvl_set(set_folder: str):
215
+ set_images = defaultdict(list)
216
+ words_path = os.path.join(set_folder, "words")
217
+
218
+ image_id = 0
219
+
220
+ for author_id in os.listdir(words_path):
221
+ author_path = os.path.join(words_path, author_id)
222
+
223
+ for image_file in os.listdir(author_path):
224
+ label = image_file.split("-")[-1].split(".")[0]
225
+ line_id = "-".join(image_file.split("-")[:-2])
226
+
227
+ stream = open(os.path.join(author_path, image_file), "rb")
228
+ bytes = bytearray(stream.read())
229
+ numpyarray = np.asarray(bytes, dtype=np.uint8)
230
+ image = cv2.imdecode(numpyarray, cv2.IMREAD_UNCHANGED)
231
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
232
+ if image is not None and image.size > 1:
233
+ set_images[int(author_id)].append(IAMImage(image, label, image_id, line_id))
234
+ image_id += 1
235
+
236
+ return set_images
237
+
238
+
239
+ def read_cvl(base_folder: str):
240
+ dataset_dict = {
241
+ 'test': read_cvl_set(os.path.join(base_folder, 'testset')),
242
+ 'train': read_cvl_set(os.path.join(base_folder, 'trainset'))
243
+ }
244
+
245
+ assert not dataset_dict['train'].keys() & dataset_dict[
246
+ 'test'].keys(), "Training and Testing set have common authors"
247
+
248
+ return dataset_dict
249
+
250
+ def pad_top(image: np.array, height: int) -> np.array:
251
+ result = np.ones((height, image.shape[1]), dtype=np.uint8) * 255
252
+ result[height - image.shape[0]:, :image.shape[1]] = image
253
+
254
+ return result
255
+
256
+
257
+ def scale_per_writer(writer_dict: dict, target_height: int, char_width: int = None) -> dict:
258
+ for author_id in writer_dict.keys():
259
+ max_height = max([image_dict.image.shape[0] for image_dict in writer_dict[author_id]])
260
+ scale_y = target_height / max_height
261
+
262
+ for image_dict in writer_dict[author_id]:
263
+ image = image_dict.image
264
+ scale_x = scale_y if char_width is None else len(image_dict.label) * char_width / image_dict.image.shape[1]
265
+ #image = cv2.resize(image, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_CUBIC)
266
+ image = resize(image, (int(image.shape[1] * scale_x), int(image.shape[0] * scale_y)))
267
+ image_dict.image = pad_top(image, target_height)
268
+
269
+ return writer_dict
270
+
271
+
272
+ def scale_images(writer_dict: dict, target_height: int, char_width: int = None) -> dict:
273
+ for author_id in writer_dict.keys():
274
+ for image_dict in writer_dict[author_id]:
275
+ scale_y = target_height / image_dict.image.shape[0]
276
+ scale_x = scale_y if char_width is None else len(image_dict.label) * char_width / image_dict.image.shape[1]
277
+ #image_dict.image = cv2.resize(image_dict.image, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_CUBIC)
278
+ image_dict.image = resize(image_dict.image, (int(image_dict.image.shape[1] * scale_x), target_height))
279
+ return writer_dict
280
+
281
+
282
+ def scale_word_width(writer_dict: dict):
283
+ for author_id in writer_dict.keys():
284
+ for image_dict in writer_dict[author_id]:
285
+ width = len(image_dict.label) * (image_dict.image.shape[0] / 2.0)
286
+ image_dict.image = resize(image_dict.image, (int(width), image_dict.image.shape[0]))
287
+ return writer_dict
288
+
289
+
290
+ def get_sentences(author_dict: dict):
291
+ collected = defaultdict(list)
292
+ for image in author_dict:
293
+ collected[image.line_id].append(image)
294
+
295
+ return [v for k, v in collected.items()]
296
+
297
+
298
+ def merge_author_words(author_words):
299
+ def try_left_merge(index: int):
300
+ if index > 0 and author_words[index - 1].line_id == author_words[index].line_id and not to_remove[index - 1] and not author_words[index - 1].label in TO_MERGE.keys():
301
+ merged = author_words[index - 1].merge(author_words[index])
302
+ author_words[index - 1] = merged
303
+ to_remove[index] = True
304
+ return True
305
+ return False
306
+
307
+ def try_right_merge(index: int):
308
+ if index < len(author_words) - 1 and author_words[index].line_id == author_words[index + 1].line_id and not to_remove[index + 1] and not author_words[index + 1].label in TO_MERGE.keys():
309
+ merged = iam_image.merge(author_words[index + 1])
310
+ author_words[index + 1] = merged
311
+ to_remove[index] = True
312
+ return True
313
+ return False
314
+
315
+ to_remove = [False for _ in range(len(author_words))]
316
+ for i in range(len(author_words)):
317
+ iam_image = author_words[i]
318
+ if iam_image.label in TO_MERGE.keys():
319
+ merge_type = TO_MERGE[iam_image.label] if TO_MERGE[iam_image.label] != 'random' else random.choice(['left', 'right'])
320
+ if merge_type == 'left':
321
+ if not try_left_merge(i):
322
+ if not try_right_merge(i):
323
+ print(f"Could not merge char: {iam_image.label}")
324
+ else:
325
+ if not try_right_merge(i):
326
+ if not try_left_merge(i):
327
+ print(f"Could not merge char: {iam_image.label}")
328
+
329
+ return [image for image, remove in zip(author_words, to_remove) if not remove], sum(to_remove)
330
+
331
+
332
+ def merge_punctuation(writer_dict: dict) -> dict:
333
+ for author_id in writer_dict.keys():
334
+ author_dict = writer_dict[author_id]
335
+
336
+ merged = 1
337
+ while merged > 0:
338
+ author_dict, merged = merge_author_words(author_dict)
339
+
340
+ writer_dict[author_id] = author_dict
341
+
342
+ return writer_dict
343
+
344
+
345
+ def filter_punctuation(writer_dict: dict) -> dict:
346
+ for author_id in writer_dict.keys():
347
+ author_list = [im for im in writer_dict[author_id] if im.label not in TO_MERGE.keys()]
348
+
349
+ writer_dict[author_id] = author_list
350
+
351
+ return writer_dict
352
+
353
+
354
+ def filter_by_width(writer_dict: dict, target_height: int = 32, min_width: int = 16, max_width: int = 17) -> dict:
355
+ def is_valid(iam_image: IAMImage) -> bool:
356
+ target_width = (target_height / iam_image.image.shape[0]) * iam_image.image.shape[1]
357
+ if len(iam_image.label) * min_width / 3 <= target_width <= len(iam_image.label) * max_width * 3:
358
+ return True
359
+ else:
360
+ return False
361
+
362
+ for author_id in writer_dict.keys():
363
+ author_list = [im for im in writer_dict[author_id] if is_valid(im)]
364
+
365
+ writer_dict[author_id] = author_list
366
+
367
+ return writer_dict
368
+
369
+
370
+ def write_data(dataset_dict: dict, location: str, height, punct_mode: str = 'none', author_scale: bool = False, uniform_char_width: bool = False):
371
+ assert punct_mode in ['none', 'filter', 'merge']
372
+ result = {}
373
+ for key in dataset_dict.keys():
374
+ result[key] = {}
375
+
376
+ subset_dict = dataset_dict[key]
377
+
378
+ subset_dict = filter_by_width(subset_dict)
379
+
380
+ if punct_mode == 'merge':
381
+ subset_dict = merge_punctuation(subset_dict)
382
+ elif punct_mode == 'filter':
383
+ subset_dict = filter_punctuation(subset_dict)
384
+
385
+ char_width = 16 if uniform_char_width else None
386
+
387
+ if author_scale:
388
+ subset_dict = scale_per_writer(subset_dict, height, char_width)
389
+ else:
390
+ subset_dict = scale_images(subset_dict, height, char_width)
391
+
392
+ for author_id in subset_dict:
393
+ author_images = []
394
+ for image_dict in subset_dict[author_id]:
395
+ author_images.append({
396
+ 'img': PIL.Image.fromarray(image_dict.image),
397
+ 'label': image_dict.label,
398
+ 'image_id': image_dict.image_id,
399
+ 'original_image_id': image_dict.iam_image_id
400
+ })
401
+ result[key][author_id] = author_images
402
+
403
+ with open(location, 'wb') as f:
404
+ pickle.dump(result, f)
405
+
406
+
407
+ def write_fid(dataset_dict: dict, location: str):
408
+ data = dataset_dict['test']
409
+ data = scale_images(data, 64, None)
410
+ for author in data.keys():
411
+ author_folder = os.path.join(location, author)
412
+ os.mkdir(author_folder)
413
+ count = 0
414
+ for image in data[author]:
415
+ img = image.image
416
+ cv2.imwrite(os.path.join(author_folder, f"{count}.png"), img.squeeze().astype(np.uint8))
417
+ count += 1
418
+
419
+
420
+ def write_images_per_author(dataset_dict: dict, output_file: str):
421
+ data = dataset_dict["test"]
422
+
423
+ result = {}
424
+
425
+ for author in data.keys():
426
+ author_images = [image.iam_image_id for image in data[author]]
427
+ result[author] = author_images
428
+
429
+ with open(output_file, 'w') as f:
430
+ json.dump(result, f)
431
+
432
+
433
+ def write_words(dataset_dict: dict, output_file):
434
+ data = dataset_dict['train']
435
+
436
+ all_words = []
437
+
438
+ for author in data.keys():
439
+ all_words.extend([image.label for image in data[author]])
440
+
441
+ with open(output_file, 'w') as f:
442
+ for word in all_words:
443
+ f.write(f"{word}\n")
444
+
445
+
446
+ if __name__ == "__main__":
447
+ data_path = r"D:\Datasets\IAM"
448
+ fid_location = r"E:/projects/evaluation/shtg_interface/data/reference_imgs/h64/iam"
449
+ height = 32
450
+ data_collection = {}
451
+
452
+ output_location = r"E:\projects\evaluation\shtg_interface\data\datasets"
453
+
454
+ data = read_iam(data_path)
455
+ test_data = dict(scale_word_width(data['test']))
456
+ train_data = dict(scale_word_width(data['train']))
457
+ test_data.update(train_data)
458
+ for key, value in test_data.items():
459
+ for image_object in value:
460
+ if len(image_object.label) <= 0 or image_object.image.size == 0:
461
+ continue
462
+ data_collection[image_object.iam_image_id] = {
463
+ 'img': image_object.image,
464
+ 'lbl': image_object.label,
465
+ 'author_id': key
466
+ }
467
+
468
+ with gzip.open(os.path.join(output_location, f"iam_w16_words_data.pkl.gz"), 'wb') as f:
469
+ pickle.dump(data_collection, f)
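Editorial note (not part of the committed file): the `__main__` block above ends by dumping a dictionary keyed by IAM image id, where each entry holds `'img'`, `'lbl'` and `'author_id'` exactly as assembled in `create_data.py`. A minimal sketch for loading and inspecting that file; the path is illustrative.

```python
# Minimal sketch: inspect the gzip pickle produced by data/create_data.py.
import gzip
import pickle

with gzip.open("iam_w16_words_data.pkl.gz", "rb") as f:
    data = pickle.load(f)

image_id, sample = next(iter(data.items()))
print(image_id, sample["lbl"], sample["author_id"], sample["img"].shape)
```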
data/dataset.py ADDED
@@ -0,0 +1,324 @@
1
+ import random
2
+ from collections import defaultdict
3
+
4
+ import matplotlib.pyplot as plt
5
+ import torch
6
+ from torch.utils.data import Dataset
7
+ import torchvision.transforms as transforms
8
+ import os
9
+ import pickle
10
+ import numpy as np
11
+ from PIL import Image
12
+ from pathlib import Path
13
+
14
+
15
+ def get_dataset_path(dataset_name, height, file_suffix, datasets_path):
16
+ if file_suffix is not None:
17
+ filename = f'{dataset_name}-{height}-{file_suffix}.pickle'
18
+ else:
19
+ filename = f'{dataset_name}-{height}.pickle'
20
+
21
+ return os.path.join(datasets_path, filename)
22
+
23
+
24
+ def get_transform(grayscale=False, convert=True):
25
+ transform_list = []
26
+ if grayscale:
27
+ transform_list.append(transforms.Grayscale(1))
28
+
29
+ if convert:
30
+ transform_list += [transforms.ToTensor()]
31
+ if grayscale:
32
+ transform_list += [transforms.Normalize((0.5,), (0.5,))]
33
+ else:
34
+ transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
35
+
36
+ return transforms.Compose(transform_list)
37
+
38
+
39
+ class TextDataset:
40
+
41
+ def __init__(self, base_path, collator_resolution, num_examples=15, target_transform=None, min_virtual_size=0, validation=False, debug=False):
42
+ self.NUM_EXAMPLES = num_examples
43
+ self.debug = debug
44
+ self.min_virtual_size = min_virtual_size
45
+
46
+ subset = 'test' if validation else 'train'
47
+
48
+ # base_path=DATASET_PATHS
49
+ file_to_store = open(base_path, "rb")
50
+ self.IMG_DATA = pickle.load(file_to_store)[subset]
51
+ self.IMG_DATA = dict(list(self.IMG_DATA.items())) # [:NUM_WRITERS])
52
+ if 'None' in self.IMG_DATA.keys():
53
+ del self.IMG_DATA['None']
54
+
55
+ self.alphabet = ''.join(sorted(set(''.join(d['label'] for d in sum(self.IMG_DATA.values(), [])))))
56
+ self.author_id = list(self.IMG_DATA.keys())
57
+
58
+ self.transform = get_transform(grayscale=True)
59
+ self.target_transform = target_transform
60
+
61
+ self.collate_fn = TextCollator(collator_resolution)
62
+
63
+ def __len__(self):
64
+ if self.debug:
65
+ return 16
66
+ return max(len(self.author_id), self.min_virtual_size)
67
+
68
+ @property
69
+ def num_writers(self):
70
+ return len(self.author_id)
71
+
72
+ def __getitem__(self, index):
73
+ index = index % len(self.author_id)
74
+
75
+ author_id = self.author_id[index]
76
+
77
+ self.IMG_DATA_AUTHOR = self.IMG_DATA[author_id]
78
+ random_idxs = random.choices([i for i in range(len(self.IMG_DATA_AUTHOR))], k=self.NUM_EXAMPLES)
79
+
80
+ word_data = random.choice(self.IMG_DATA_AUTHOR)
81
+ real_img = self.transform(word_data['img'].convert('L'))
82
+ real_labels = word_data['label'].encode()
83
+
84
+ imgs = [np.array(self.IMG_DATA_AUTHOR[idx]['img'].convert('L')) for idx in random_idxs]
85
+ slabels = [self.IMG_DATA_AUTHOR[idx]['label'].encode() for idx in random_idxs]
86
+
87
+ max_width = 192 # [img.shape[1] for img in imgs]
88
+
89
+ imgs_pad = []
90
+ imgs_wids = []
91
+
92
+ for img in imgs:
93
+ img_height, img_width = img.shape[0], img.shape[1]
94
+ output_img = np.ones((img_height, max_width), dtype='float32') * 255.0
95
+ output_img[:, :img_width] = img[:, :max_width]
96
+
97
+ imgs_pad.append(self.transform(Image.fromarray(output_img.astype(np.uint8))))
98
+ imgs_wids.append(img_width)
99
+
100
+ imgs_pad = torch.cat(imgs_pad, 0)
101
+
102
+ item = {
103
+ 'simg': imgs_pad, # N images (15) that come from the same author [N (15), H (32), MAX_W (192)]
104
+ 'swids': imgs_wids, # widths of the N images [list(N)]
105
+ 'img': real_img, # the input image [1, H (32), W]
106
+ 'label': real_labels, # the label of the input image [byte]
107
+ 'img_path': 'img_path',
108
+ 'idx': 'indexes',
109
+ 'wcl': index, # id of the author [int],
110
+ 'slabels': slabels,
111
+ 'author_id': author_id
112
+ }
113
+ return item
114
+
115
+ def get_stats(self):
116
+ char_counts = defaultdict(lambda: 0)
117
+ total = 0
118
+
119
+ for author in self.IMG_DATA.keys():
120
+ for data in self.IMG_DATA[author]:
121
+ for char in data['label']:
122
+ char_counts[char] += 1
123
+ total += 1
124
+
125
+ char_counts = {k: 1.0 / (v / total) for k, v in char_counts.items()}
126
+
127
+ return char_counts
128
+
129
+
130
+ class TextCollator(object):
131
+ def __init__(self, resolution):
132
+ self.resolution = resolution
133
+
134
+ def __call__(self, batch):
135
+ if isinstance(batch[0], list):
136
+ batch = sum(batch, [])
137
+ img_path = [item['img_path'] for item in batch]
138
+ width = [item['img'].shape[2] for item in batch]
139
+ indexes = [item['idx'] for item in batch]
140
+ simgs = torch.stack([item['simg'] for item in batch], 0)
141
+ wcls = torch.Tensor([item['wcl'] for item in batch])
142
+ swids = torch.Tensor([item['swids'] for item in batch])
143
+ imgs = torch.ones([len(batch), batch[0]['img'].shape[0], batch[0]['img'].shape[1], max(width)],
144
+ dtype=torch.float32)
145
+ for idx, item in enumerate(batch):
146
+ try:
147
+ imgs[idx, :, :, 0:item['img'].shape[2]] = item['img']
148
+ except:
149
+ print(imgs.shape)
150
+ item = {'img': imgs, 'img_path': img_path, 'idx': indexes, 'simg': simgs, 'swids': swids, 'wcl': wcls}
151
+ if 'label' in batch[0].keys():
152
+ labels = [item['label'] for item in batch]
153
+ item['label'] = labels
154
+ if 'slabels' in batch[0].keys():
155
+ slabels = [item['slabels'] for item in batch]
156
+ item['slabels'] = np.array(slabels)
157
+ if 'z' in batch[0].keys():
158
+ z = torch.stack([item['z'] for item in batch])
159
+ item['z'] = z
160
+ return item
161
+
162
+
163
+ class CollectionTextDataset(Dataset):
164
+ def __init__(self, datasets, datasets_path, dataset_class, file_suffix=None, height=32, **kwargs):
165
+ self.datasets = {}
166
+ for dataset_name in sorted(datasets.split(',')):
167
+ dataset_file = get_dataset_path(dataset_name, height, file_suffix, datasets_path)
168
+ dataset = dataset_class(dataset_file, **kwargs)
169
+ self.datasets[dataset_name] = dataset
170
+ self.alphabet = ''.join(sorted(set(''.join(d.alphabet for d in self.datasets.values()))))
171
+
172
+ def __len__(self):
173
+ return sum(len(d) for d in self.datasets.values())
174
+
175
+ @property
176
+ def num_writers(self):
177
+ return sum(d.num_writers for d in self.datasets.values())
178
+
179
+ def __getitem__(self, index):
180
+ for dataset in self.datasets.values():
181
+ if index < len(dataset):
182
+ return dataset[index]
183
+ index -= len(dataset)
184
+ raise IndexError
185
+
186
+ def get_dataset(self, index):
187
+ for dataset_name, dataset in self.datasets.items():
188
+ if index < len(dataset):
189
+ return dataset_name
190
+ index -= len(dataset)
191
+ raise IndexError
192
+
193
+ def collate_fn(self, batch):
194
+ return self.datasets[self.get_dataset(0)].collate_fn(batch)
195
+
196
+
197
+ class FidDataset(Dataset):
198
+ def __init__(self, base_path, collator_resolution, num_examples=15, target_transform=None, mode='train', style_dataset=None):
199
+ self.NUM_EXAMPLES = num_examples
200
+
201
+ # base_path=DATASET_PATHS
202
+ with open(base_path, "rb") as f:
203
+ self.IMG_DATA = pickle.load(f)
204
+
205
+ self.IMG_DATA = self.IMG_DATA[mode]
206
+ if 'None' in self.IMG_DATA.keys():
207
+ del self.IMG_DATA['None']
208
+
209
+ self.STYLE_IMG_DATA = None
210
+ if style_dataset is not None:
211
+ with open(style_dataset, "rb") as f:
212
+ self.STYLE_IMG_DATA = pickle.load(f)
213
+
214
+ self.STYLE_IMG_DATA = self.STYLE_IMG_DATA[mode]
215
+ if 'None' in self.STYLE_IMG_DATA.keys():
216
+ del self.STYLE_IMG_DATA['None']
217
+
218
+ self.alphabet = ''.join(sorted(set(''.join(d['label'] for d in sum(self.IMG_DATA.values(), [])))))
219
+ self.author_id = sorted(self.IMG_DATA.keys())
220
+
221
+ self.transform = get_transform(grayscale=True)
222
+ self.target_transform = target_transform
223
+ self.dataset_size = sum(len(samples) for samples in self.IMG_DATA.values())
224
+ self.collate_fn = TextCollator(collator_resolution)
225
+
226
+ def __len__(self):
227
+ return self.dataset_size
228
+
229
+ @property
230
+ def num_writers(self):
231
+ return len(self.author_id)
232
+
233
+ def __getitem__(self, index):
234
+ NUM_SAMPLES = self.NUM_EXAMPLES
235
+ sample, author_id = None, None
236
+ for author_id, samples in self.IMG_DATA.items():
237
+ if index < len(samples):
238
+ sample, author_id = samples[index], author_id
239
+ break
240
+ index -= len(samples)
241
+
242
+ real_image = self.transform(sample['img'].convert('L'))
243
+ real_label = sample['label'].encode()
244
+
245
+ style_dataset = self.STYLE_IMG_DATA if self.STYLE_IMG_DATA is not None else self.IMG_DATA
246
+
247
+ author_style_images = style_dataset[author_id]
248
+ random_idxs = np.random.choice(len(author_style_images), NUM_SAMPLES, replace=True)
249
+ style_images = [np.array(author_style_images[idx]['img'].convert('L')) for idx in random_idxs]
250
+
251
+ max_width = 192
252
+
253
+ imgs_pad = []
254
+ imgs_wids = []
255
+
256
+ for img in style_images:
257
+ img = 255 - img
258
+ img_height, img_width = img.shape[0], img.shape[1]
259
+ outImg = np.zeros((img_height, max_width), dtype='float32')
260
+ outImg[:, :img_width] = img[:, :max_width]
261
+
262
+ img = 255 - outImg
263
+
264
+ imgs_pad.append(self.transform(Image.fromarray(img.astype(np.uint8))))
265
+ imgs_wids.append(img_width)
266
+
267
+ imgs_pad = torch.cat(imgs_pad, 0)
268
+
269
+ item = {
270
+ 'simg': imgs_pad, # widths of the N images [list(N)]
271
+ 'swids': imgs_wids, # N images (15) that come from the same author [N (15), H (32), MAX_W (192)]
272
+ 'img': real_image, # the input image [1, H (32), W]
273
+ 'label': real_label, # the label of the input image [byte]
274
+ 'img_path': 'img_path',
275
+ 'idx': sample['img_id'] if 'img_id' in sample.keys() else sample['image_id'],
276
+ 'wcl': int(author_id) # id of the author [int]
277
+ }
278
+ return item
279
+
280
+
281
+ class FolderDataset:
282
+ def __init__(self, folder_path, num_examples=15, word_lengths=None):
283
+ folder_path = Path(folder_path)
284
+ self.imgs = list([p for p in folder_path.iterdir() if not p.suffix == '.txt'])
285
+ self.transform = get_transform(grayscale=True)
286
+ self.num_examples = num_examples
287
+ self.word_lengths = word_lengths
288
+
289
+ def __len__(self):
290
+ return len(self.imgs)
291
+
292
+ def sample_style(self):
293
+ random_idxs = np.random.choice(len(self.imgs), self.num_examples, replace=False)
294
+ image_names = [self.imgs[idx].stem for idx in random_idxs]
295
+ imgs = [Image.open(self.imgs[idx]).convert('L') for idx in random_idxs]
296
+ if self.word_lengths is None:
297
+ imgs = [img.resize((img.size[0] * 32 // img.size[1], 32), Image.BILINEAR) for img in imgs]
298
+ else:
299
+ imgs = [img.resize((self.word_lengths[name] * 16, 32), Image.BILINEAR) for img, name in zip(imgs, image_names)]
300
+ imgs = [np.array(img) for img in imgs]
301
+
302
+ max_width = 192 # [img.shape[1] for img in imgs]
303
+
304
+ imgs_pad = []
305
+ imgs_wids = []
306
+
307
+ for img in imgs:
308
+ img = 255 - img
309
+ img_height, img_width = img.shape[0], img.shape[1]
310
+ outImg = np.zeros((img_height, max_width), dtype='float32')
311
+ outImg[:, :img_width] = img[:, :max_width]
312
+
313
+ img = 255 - outImg
314
+
315
+ imgs_pad.append(self.transform(Image.fromarray(img.astype(np.uint8))))
316
+ imgs_wids.append(img_width)
317
+
318
+ imgs_pad = torch.cat(imgs_pad, 0)
319
+
320
+ item = {
321
+ 'simg': imgs_pad, # widths of the N images [list(N)]
322
+ 'swids': imgs_wids, # N images (15) that come from the same author [N (15), H (32), MAX_W (192)]
323
+ }
324
+ return item
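Editorial note (not part of the committed file): a minimal sketch of how these datasets are expected to be consumed, mirroring the `CollectionTextDataset` constructor call used in `generate/authors.py`. It assumes `files/IAM-32.pickle` is in place; batch size and `num_workers` are illustrative.

```python
# Minimal sketch: wire CollectionTextDataset into a DataLoader with its own collate_fn.
from torch.utils.data import DataLoader
from data.dataset import CollectionTextDataset, TextDataset

dataset = CollectionTextDataset(
    "IAM", "files", TextDataset,
    num_examples=15, collator_resolution=16, validation=False,
)
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0,
                    collate_fn=dataset.collate_fn)

batch = next(iter(loader))
print(batch["img"].shape, batch["simg"].shape, len(batch["label"]))
```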
data/iam_test.py ADDED
@@ -0,0 +1,51 @@
+ import os
+
+
+ def test_split():
+     iam_path = r"C:\Users\bramv\Documents\Werk\Research\Unimore\datasets\IAM"
+
+     original_set_names = ["trainset.txt", "validationset1.txt", "validationset2.txt", "testset.txt"]
+     original_set_ids = []
+
+     print("ORIGINAL IAM")
+     print("---------------------")
+
+     for set_name in original_set_names:
+         with open(os.path.join(iam_path, set_name), 'r') as f:
+             set_form_ids = ["-".join(l.rstrip().split("-")[:-1]) for l in f]
+
+         form_to_id = {}
+         with open(os.path.join(iam_path, "forms.txt"), 'r') as f:
+             for line in f:
+                 if line.startswith("#"):
+                     continue
+                 form, id, *_ = line.split(" ")
+                 assert form not in form_to_id.keys() or form_to_id[form] == id
+                 form_to_id[form] = int(id)
+
+         set_authors = [form_to_id[form] for form in set_form_ids]
+
+         set_authors = set(sorted(set_authors))
+         original_set_ids.append(set_authors)
+         print(f"{set_name} count: {len(set_authors)}")
+
+     htg_set_names = ["gan.iam.tr_va.gt.filter27", "gan.iam.test.gt.filter27"]
+
+     print("\n\nHTG IAM")
+     print("---------------------")
+
+     for set_name in htg_set_names:
+         with open(os.path.join(iam_path, set_name), 'r') as f:
+             set_authors = [int(l.split(",")[0]) for l in f]
+
+         set_authors = set(set_authors)
+
+         print(f"{set_name} count: {len(set_authors)}")
+         for name, original_set in zip(original_set_names, original_set_ids):
+             intr = set_authors.intersection(original_set)
+             print(f"\t intersection with {name}: {len(intr)}")
+
+
+ if __name__ == "__main__":
+     test_split()
data/show_dataset.py ADDED
@@ -0,0 +1,149 @@
1
+ import os
2
+ import pickle
3
+ import random
4
+ import shutil
5
+
6
+ import cv2
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+
10
+ from data.dataset import get_transform
11
+
12
+
13
+ def summarize_dataset(data: dict):
14
+ print(f"Training authors: {len(data['train'].keys())} \t Testing authors: {len(data['test'].keys())}")
15
+ training_images = sum([len(data['train'][k]) for k in data['train'].keys()])
16
+ testing_images = sum([len(data['test'][k]) for k in data['test'].keys()])
17
+ print(f"Training images: {training_images} \t Testing images: {testing_images}")
18
+
19
+
20
+ def compare_data(path_a: str, path_b: str):
21
+ with open(path_a, 'rb') as f:
22
+ data_a = pickle.load(f)
23
+ summarize_dataset(data_a)
24
+
25
+ with open(path_b, 'rb') as f:
26
+ data_b = pickle.load(f)
27
+ summarize_dataset(data_b)
28
+
29
+ training_a = data_a['train']
30
+ training_b = data_b['train']
31
+
32
+ training_a = {int(k): v for k, v in training_a.items()}
33
+ training_b = {int(k): v for k, v in training_b.items()}
34
+
35
+ while True:
36
+ author = random.choice(list(training_a.keys()))
37
+
38
+ if author in training_b.keys():
39
+ author_images_a = [np.array(im_dict["img"]) for im_dict in training_a[author]]
40
+ author_images_b = [np.array(im_dict["img"]) for im_dict in training_b[author]]
41
+
42
+ labels_a = [str(im_dict["label"]) for im_dict in training_a[author]]
43
+ labels_b = [str(im_dict["label"]) for im_dict in training_b[author]]
44
+
45
+ vis_a = np.hstack(author_images_a[:10])
46
+ vis_b = np.hstack(author_images_b[:10])
47
+
48
+ cv2.imshow("Author a", vis_a)
49
+ cv2.imshow("Author b", vis_b)
50
+
51
+ cv2.waitKey(0)
52
+
53
+ else:
54
+ print(f"Author: {author} not found in second dataset")
55
+
56
+
57
+ def show_dataset(path: str, samples: int = 10):
58
+ with open(path, 'rb') as f:
59
+ data = pickle.load(f)
60
+ summarize_dataset(data)
61
+
62
+ training = data['train']
63
+
64
+ author = training['013']
65
+ author_images = [np.array(im_dict["img"]).astype(np.uint8) for im_dict in author]
66
+
67
+ for img in author_images:
68
+ cv2.imshow('image', img)
69
+ cv2.waitKey(0)
70
+
71
+ for author in list(training.keys()):
72
+
73
+ author_images = [np.array(im_dict["img"]).astype(np.uint8) for im_dict in training[author]]
74
+ labels = [str(im_dict["label"]) for im_dict in training[author]]
75
+
76
+ vis = np.hstack(author_images[:samples])
77
+ print(f"Author: {author}")
78
+ cv2.destroyAllWindows()
79
+ cv2.imshow("vis", vis)
80
+ cv2.waitKey(0)
81
+
82
+
83
+ def test_transform(path: str):
84
+ with open(path, 'rb') as f:
85
+ data = pickle.load(f)
86
+ summarize_dataset(data)
87
+
88
+ training = data['train']
89
+ transform = get_transform(grayscale=True)
90
+
91
+ for author_id in training.keys():
92
+ author = training[author_id]
93
+ for image_dict in author:
94
+ original_image = image_dict['img'].convert('L')
95
+ transformed_image = transform(original_image).detach().numpy()
96
+ restored_image = (((transformed_image + 1) / 2) * 255).astype(np.uint8)
97
+ restored_image = np.squeeze(restored_image)
98
+ original_image = np.array(original_image)
99
+
100
+ wrong_pixels = (original_image != restored_image).astype(np.uint8) * 255
101
+
102
+ combined = np.hstack((restored_image, original_image, wrong_pixels))
103
+
104
+ cv2.imshow("original", original_image)
105
+ cv2.imshow("restored", restored_image)
106
+ cv2.imshow("combined", combined)
107
+
108
+ f, ax = plt.subplots(1, 2)
109
+ ax[0].hist(original_image.flatten())
110
+ ax[1].hist(restored_image.flatten())
111
+ plt.show()
112
+
113
+ cv2.waitKey(0)
114
+
115
+ def dump_words():
116
+ data_path = r"..\files\IAM-32.pickle"
117
+
118
+ p_mark = 'point'
119
+ p = '.'
120
+
121
+ with open(data_path, 'rb') as f:
122
+ data = pickle.load(f)
123
+
124
+ training = data['train']
125
+
126
+ target_folder = f"../saved_images/debug/{p_mark}"
127
+
128
+ if os.path.exists(target_folder):
129
+ shutil.rmtree(target_folder)
130
+
131
+ os.mkdir(target_folder)
132
+
133
+ count = 0
134
+
135
+ for author in list(training.keys()):
136
+
137
+ author_images = [np.array(im_dict["img"]).astype(np.uint8) for im_dict in training[author]]
138
+ labels = [str(im_dict["label"]) for im_dict in training[author]]
139
+
140
+ for img, label in zip(author_images, labels):
141
+ if p in label:
142
+ cv2.imwrite(os.path.join(target_folder, f"{count}.png"), img)
143
+ count += 1
144
+
145
+
146
+ if __name__ == "__main__":
147
+ test_transform("../files/IAM-32.pickle")
148
+ #show_dataset("../files/IAM-32.pickle")
149
+ #compare_data(r"../files/IAM-32.pickle", r"../files/_IAM-32.pickle")
files/IAM-32-pa.pickle ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:92bff8330e8f404b5f382846266257b5cac45d6c27908df5c3ee7d0c77a0ee95
+ size 245981914
files/IAM-32.pickle ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c56d4055470c26a30dbbdf7f2e232eb86ffc714b803651dbac5576ee2bc97937
+ size 590113103
files/cvl_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b47fe3ffe291bb3e52db0643125a99206840884181ed21312bcbe2cdd86303f0
+ size 163050271
files/english_words.txt ADDED
The diff for this file is too large to render. See raw diff
 
files/files ADDED
@@ -0,0 +1 @@
+ files
files/hwt.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:999f85148e34e30242c1aa9ed7063c9dbc9da008f868ed26cb6ed923f9d8c0bd
+ size 163050271
files/resnet_18_pretrained.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bf5f5f6a94152dc4b0e9f2e390d658ef621efead3824cd494d3a82a6c8ceb5e0
+ size 48833885
files/unifont.pickle ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0804979068f0d169b343fbe0fe8d7ff478165d07a671fcf52e20f625db8e7f9f
+ size 16978300
files/vatr.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:65b67f1738bf74d5bf612f7f35e2c8c9560568d7efe422beb9132e1bb68bbef8
+ size 565758212
files/vatrpp.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c02f950d19cf3df3cfa6fe97114557e16a51bd3b910da6b5a2359a29851b84b6
+ size 561198056
generate.py ADDED
@@ -0,0 +1,49 @@
1
+ import argparse
2
+ from generate import generate_text, generate_authors, generate_fid, generate_page, generate_ocr, generate_ocr_msgpack
3
+ from generate.ocr import generate_ocr_reference
4
+ from util.misc import add_vatr_args
5
+
6
+ if __name__ == '__main__':
7
+ parser = argparse.ArgumentParser()
8
+ parser.add_argument("action", choices=['text', 'fid', 'page', 'authors', 'ocr'])
9
+
10
+ parser.add_argument("-s", "--style-folder", default='files/style_samples/00', type=str)
11
+ parser.add_argument("-t", "--text", default='That\'s one small step for man, one giant leap for mankind ΑαΒβΓγΔδ', type=str)
12
+ parser.add_argument("--text-path", default=None, type=str, help='Path to text file with texts to generate')
13
+ parser.add_argument("-c", "--checkpoint", default='files/vatr.pth', type=str)
14
+ parser.add_argument("-o", "--output", default=None, type=str)
15
+ parser.add_argument("--count", default=1000, type=int)
16
+ parser.add_argument("-a", "--align", action='store_true')
17
+ parser.add_argument("--at-once", action='store_true')
18
+ parser.add_argument("--output-style", action='store_true')
19
+ parser.add_argument("-d", "--dataset-path", type=str)
20
+ parser.add_argument("--target-dataset-path", type=str, default=None)
21
+ parser.add_argument("--charset-file", type=str, default=None)
22
+ parser.add_argument("--interp-styles", action='store_true')
23
+
24
+ parser.add_argument("--test-only", action='store_true')
25
+ parser.add_argument("--fake-only", action='store_true')
26
+ parser.add_argument("--all-epochs", action='store_true')
27
+ parser.add_argument("--long-tail", action='store_true')
28
+ parser.add_argument("--msgpack", action='store_true')
29
+ parser.add_argument("--reference", action='store_true')
30
+ parser.add_argument("--test-set", action='store_true')
31
+
32
+ parser = add_vatr_args(parser)
33
+ args = parser.parse_args()
34
+
35
+ if args.action == 'text':
36
+ generate_text(args)
37
+ elif args.action == 'authors':
38
+ generate_authors(args)
39
+ elif args.action == 'fid':
40
+ generate_fid(args)
41
+ elif args.action == 'page':
42
+ generate_page(args)
43
+ elif args.action == 'ocr':
44
+ if args.msgpack:
45
+ generate_ocr_msgpack(args)
46
+ elif args.reference:
47
+ generate_ocr_reference(args)
48
+ else:
49
+ generate_ocr(args)
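A minimal sketch of driving `generate.py` from another Python script via `subprocess`; the flag values are examples taken from the argparse defaults above.

```python
# Invoke the 'text' action of generate.py as a subprocess (example paths/values).
import subprocess

subprocess.run(
    [
        "python", "generate.py", "text",
        "--checkpoint", "files/vatrpp.pth",
        "--style-folder", "files/style_samples/00",
        "--text", "one small step for man",
        "--output", "files/output.png",
    ],
    check=True,  # raise CalledProcessError if generation fails
)
```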
generate/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from generate.text import generate_text
2
+ from generate.fid import generate_fid
3
+ from generate.authors import generate_authors
4
+ from generate.page import generate_page
5
+ from generate.ocr import generate_ocr, generate_ocr_msgpack
generate/authors.py ADDED
@@ -0,0 +1,48 @@
1
+ import os
2
+ import shutil
3
+
4
+ import cv2
5
+ import numpy as np
6
+
7
+ from data.dataset import CollectionTextDataset, TextDataset
8
+ from generate.util import stack_lines
9
+ from generate.writer import Writer
10
+
11
+
12
+ def generate_authors(args):
13
+ dataset = CollectionTextDataset(
14
+ args.dataset, 'files', TextDataset, file_suffix=args.file_suffix, num_examples=args.num_examples,
15
+ collator_resolution=args.resolution, validation=args.test_set
16
+ )
17
+
18
+ args.num_writers = dataset.num_writers
19
+
20
+ writer = Writer(args.checkpoint, args, only_generator=True)
21
+
22
+ if args.text.endswith(".txt"):
23
+ with open(args.text, 'r') as f:
24
+ lines = [l.rstrip() for l in f]
25
+ else:
26
+ lines = [args.text]
27
+
28
+ output_dir = "saved_images/author_samples/"
29
+ if os.path.exists(output_dir):
30
+ shutil.rmtree(output_dir)
31
+ os.mkdir(output_dir)
32
+
33
+ fakes, author_ids, style_images = writer.generate_authors(lines, dataset, args.align, args.at_once)
34
+
35
+ for fake, author_id, style in zip(fakes, author_ids, style_images):
36
+ author_dir = os.path.join(output_dir, str(author_id))
37
+ os.mkdir(author_dir)
38
+
39
+ for i, line in enumerate(fake):
40
+ cv2.imwrite(os.path.join(author_dir, f"line_{i}.png"), line)
41
+
42
+ total = stack_lines(fake)
43
+ cv2.imwrite(os.path.join(author_dir, "total.png"), total)
44
+
45
+ if args.output_style:
46
+ for i, image in enumerate(style):
47
+ cv2.imwrite(os.path.join(author_dir, f"style_{i}.png"), image)
48
+
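`generate_authors` writes one folder per author containing `line_{i}.png`, `total.png` and, optionally, `style_{i}.png`. A small sketch for collecting the `total.png` previews after a run (folder layout as created above):

```python
# Gather the per-author preview images written by generate_authors.
import os
import cv2

output_dir = "saved_images/author_samples/"
previews = {}
for author_id in sorted(os.listdir(output_dir)):
    total_path = os.path.join(output_dir, author_id, "total.png")
    if os.path.isfile(total_path):
        previews[author_id] = cv2.imread(total_path, cv2.IMREAD_GRAYSCALE)

print(f"Loaded previews for {len(previews)} authors")
```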
generate/fid.py ADDED
@@ -0,0 +1,63 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ import torch.utils.data
6
+
7
+ from data.dataset import FidDataset
8
+ from generate.writer import Writer
9
+
10
+
11
+ def generate_fid(args):
12
+ if 'iam' in args.target_dataset_path.lower():
13
+ args.num_writers = 339
14
+ elif 'cvl' in args.target_dataset_path.lower():
15
+ args.num_writers = 283
16
+ else:
17
+ raise ValueError(f"Cannot determine the number of writers from '{args.target_dataset_path}'")
18
+
19
+ args.vocab_size = len(args.alphabet)
20
+
21
+ dataset_train = FidDataset(base_path=args.target_dataset_path, num_examples=args.num_examples, collator_resolution=args.resolution, mode='train', style_dataset=args.dataset_path)
22
+ train_loader = torch.utils.data.DataLoader(
23
+ dataset_train,
24
+ batch_size=args.batch_size,
25
+ shuffle=False,
26
+ num_workers=args.num_workers,
27
+ pin_memory=True, drop_last=False,
28
+ collate_fn=dataset_train.collate_fn
29
+ )
30
+
31
+ dataset_test = FidDataset(base_path=args.target_dataset_path, num_examples=args.num_examples, collator_resolution=args.resolution, mode='test', style_dataset=args.dataset_path)
32
+ test_loader = torch.utils.data.DataLoader(
33
+ dataset_test,
34
+ batch_size=args.batch_size,
35
+ shuffle=False,
36
+ num_workers=0,
37
+ pin_memory=True, drop_last=False,
38
+ collate_fn=dataset_test.collate_fn
39
+ )
40
+
41
+ args.output = 'saved_images' if args.output is None else args.output
42
+ args.output = Path(args.output) / 'fid' / args.target_dataset_path.split("/")[-1].replace(".pickle", "").replace("-", "")
43
+
44
+ model_folder = args.checkpoint.split("/")[-2] if args.checkpoint.endswith(".pth") else args.checkpoint.split("/")[-1]
45
+ model_tag = model_folder.split("-")[-1] if "-" in model_folder else "vatr"
46
+ model_tag += "_" + args.dataset_path.split("/")[-1].replace(".pickle", "").replace("-", "")
47
+
48
+ if not args.all_epochs:
49
+ writer = Writer(args.checkpoint, args, only_generator=True)
50
+ if not args.test_only:
51
+ writer.generate_fid(args.output, train_loader, model_tag=model_tag, split='train', fake_only=args.fake_only, long_tail_only=args.long_tail)
52
+ writer.generate_fid(args.output, test_loader, model_tag=model_tag, split='test', fake_only=args.fake_only, long_tail_only=args.long_tail)
53
+ else:
54
+ epochs = sorted([int(f.split("_")[0]) for f in os.listdir(args.checkpoint) if "_" in f])
55
+ generate_real = True
56
+
57
+ for epoch in epochs:
58
+ checkpoint_path = os.path.join(args.checkpoint, f"{str(epoch).zfill(4)}_model.pth")
59
+ writer = Writer(checkpoint_path, args, only_generator=True)
60
+ writer.generate_fid(args.output, test_loader, model_tag=f"{model_tag}_{epoch}", split='test', fake_only=not generate_real, long_tail_only=args.long_tail)
61
+ generate_real = False
62
+
63
+ print('Done')
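The output folder and model tag in `generate_fid` are derived purely from the checkpoint and dataset paths. A sketch with example values (the paths below are placeholders, not files shipped in this commit):

```python
# Mirror the string manipulation used above to build the FID output layout.
from pathlib import Path

target_dataset_path = "files/IAM-32.pickle"          # example --target-dataset-path
checkpoint = "saved_models/vatr-pa/0100_model.pth"   # example --checkpoint
dataset_path = "files/IAM-32-pa.pickle"              # example --dataset-path

output = Path("saved_images") / "fid" / target_dataset_path.split("/")[-1].replace(".pickle", "").replace("-", "")

model_folder = checkpoint.split("/")[-2] if checkpoint.endswith(".pth") else checkpoint.split("/")[-1]
model_tag = model_folder.split("-")[-1] if "-" in model_folder else "vatr"
model_tag += "_" + dataset_path.split("/")[-1].replace(".pickle", "").replace("-", "")

print(output)     # saved_images/fid/IAM32
print(model_tag)  # pa_IAM32pa
```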
generate/ocr.py ADDED
@@ -0,0 +1,72 @@
1
+ import os
2
+ import shutil
3
+
4
+ import cv2
5
+ import msgpack
6
+ import torch
7
+
8
+ from data.dataset import CollectionTextDataset, TextDataset, FolderDataset, FidDataset, get_dataset_path
9
+ from generate.writer import Writer
10
+ from util.text import get_generator
11
+
12
+
13
+ def generate_ocr(args):
14
+ """
15
+ Generate OCR training data. The words to render come from the given text generator.
16
+ """
17
+ dataset = CollectionTextDataset(
18
+ args.dataset, 'files', TextDataset, file_suffix=args.file_suffix, num_examples=args.num_examples,
19
+ collator_resolution=args.resolution, validation=True
20
+ )
21
+ args.num_writers = dataset.num_writers
22
+
23
+ writer = Writer(args.checkpoint, args, only_generator=True)
24
+
25
+ generator = get_generator(args)
26
+
27
+ writer.generate_ocr(dataset, args.count, interpolate_style=args.interp_styles, output_folder=args.output, text_generator=generator)
28
+
29
+
30
+ def generate_ocr_reference(args):
31
+ """
32
+ Generate OCR training data. The words to render are taken from the given dataset, and the corresponding reference images are also saved.
33
+ """
34
+ dataset = CollectionTextDataset(
35
+ args.dataset, 'files', TextDataset, file_suffix=args.file_suffix, num_examples=args.num_examples,
36
+ collator_resolution=args.resolution, validation=True
37
+ )
38
+
39
+ #dataset = FidDataset(get_dataset_path(args.dataset, 32, args.file_suffix, 'files'), mode='test', collator_resolution=args.resolution)
40
+
41
+ args.num_writers = dataset.num_writers
42
+
43
+ writer = Writer(args.checkpoint, args, only_generator=True)
44
+
45
+ writer.generate_ocr(dataset, args.count, interpolate_style=args.interp_styles, output_folder=args.output, long_tail=args.long_tail)
46
+
47
+
48
+ def generate_ocr_msgpack(args):
49
+ """
50
+ Generate an OCR dataset. The words to render are listed in the given msgpack file.
51
+ """
52
+ dataset = FolderDataset(args.dataset_path)
53
+ args.num_writers = 339
54
+
55
+ if args.charset_file:
56
+ charset = msgpack.load(open(args.charset_file, 'rb'), use_list=False, strict_map_key=False)
57
+ args.alphabet = "".join(charset['char2idx'].keys())
58
+
59
+ writer = Writer(args.checkpoint, args, only_generator=True)
60
+
61
+ lines = msgpack.load(open(args.text_path, 'rb'), use_list=False)
62
+
63
+ print(f"Generating {len(lines)} images to {args.output}")
64
+
65
+ for i, (filename, target) in enumerate(lines):
66
+ if not os.path.exists(os.path.join(args.output, filename)):
67
+ style = torch.unsqueeze(dataset.sample_style()['simg'], dim=0).to(args.device)
68
+ fake = writer.create_fake_sentence(style, target, at_once=True)
69
+
70
+ cv2.imwrite(os.path.join(args.output, filename), fake)
71
+
72
+ print(f"Done")
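Both OCR generation paths write a `labels.csv` with a `filename,words` header next to the images. A sketch for reading it back (the output folder is an example; note the writer does not quote fields, so labels containing commas would need special handling):

```python
# Read the annotations produced by Writer.generate_ocr.
import csv
import os

output_folder = "saved_images/ocr"  # example --output value
with open(os.path.join(output_folder, "generated", "labels.csv"), newline='') as f:
    samples = [(row["filename"], row["words"]) for row in csv.DictReader(f)]

print(f"{len(samples)} generated samples")
```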
generate/page.py ADDED
@@ -0,0 +1,57 @@
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+
7
+ from data.dataset import CollectionTextDataset, TextDataset
8
+ from models.model import VATr
9
+ from util.loading import load_checkpoint, load_generator
10
+
11
+
12
+ def generate_page(args):
13
+ args.output = 'vatr' if args.output is None else args.output
14
+
15
+ args.vocab_size = len(args.alphabet)
16
+
17
+ dataset = CollectionTextDataset(
18
+ args.dataset, 'files', TextDataset, file_suffix=args.file_suffix, num_examples=args.num_examples,
19
+ collator_resolution=args.resolution
20
+ )
21
+ datasetval = CollectionTextDataset(
22
+ args.dataset, 'files', TextDataset, file_suffix=args.file_suffix, num_examples=args.num_examples,
23
+ collator_resolution=args.resolution, validation=True
24
+ )
25
+
26
+ args.num_writers = dataset.num_writers
27
+
28
+ model = VATr(args)
29
+ checkpoint = torch.load(args.checkpoint, map_location=args.device)
30
+ model = load_generator(model, checkpoint)
31
+
32
+ train_loader = torch.utils.data.DataLoader(
33
+ dataset,
34
+ batch_size=8,
35
+ shuffle=True,
36
+ num_workers=0,
37
+ pin_memory=True, drop_last=True,
38
+ collate_fn=dataset.collate_fn)
39
+
40
+ val_loader = torch.utils.data.DataLoader(
41
+ datasetval,
42
+ batch_size=8,
43
+ shuffle=True,
44
+ num_workers=0,
45
+ pin_memory=True, drop_last=True,
46
+ collate_fn=datasetval.collate_fn)
47
+
48
+ data_train = next(iter(train_loader))
49
+ data_val = next(iter(val_loader))
50
+
51
+ model.eval()
52
+ with torch.no_grad():
53
+ page = model._generate_page(data_train['simg'].to(args.device), data_val['swids'])
54
+ page_val = model._generate_page(data_val['simg'].to(args.device), data_val['swids'])
55
+
56
+ cv2.imwrite(os.path.join("saved_images", "pages", f"{args.output}_train.png"), (page * 255).astype(np.uint8))
57
+ cv2.imwrite(os.path.join("saved_images", "pages", f"{args.output}_val.png"), (page_val * 255).astype(np.uint8))
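`_generate_page` is assumed here to return a float image in [0, 1], which is why the result is scaled to 8-bit before `cv2.imwrite`. A minimal sketch of that save step with a stand-in array:

```python
# Scale a float page image in [0, 1] to uint8 and write it with OpenCV.
import os
import cv2
import numpy as np

page = np.random.rand(256, 512)  # stand-in for model._generate_page(...)
os.makedirs(os.path.join("saved_images", "pages"), exist_ok=True)
cv2.imwrite(os.path.join("saved_images", "pages", "example_train.png"), (page * 255).astype(np.uint8))
```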
generate/text.py ADDED
@@ -0,0 +1,24 @@
1
+ from pathlib import Path
2
+
3
+ import cv2
4
+
5
+ from generate.writer import Writer
6
+
7
+
8
+ def generate_text(args):
9
+ if args.text_path is not None:
10
+ with open(args.text_path, 'r') as f:
11
+ args.text = f.read()
12
+ args.text = args.text.splitlines()
13
+ args.output = 'files/output.png' if args.output is None else args.output
14
+ args.output = Path(args.output)
15
+ args.output.parent.mkdir(parents=True, exist_ok=True)
16
+ args.num_writers = 0
17
+
18
+ writer = Writer(args.checkpoint, args, only_generator=True)
19
+ writer.set_style_folder(args.style_folder)
20
+ fakes = writer.generate(args.text, args.align)
21
+ for i, fake in enumerate(fakes):
22
+ dst_path = args.output.parent / (args.output.stem + f'_{i:03d}' + args.output.suffix)
23
+ cv2.imwrite(str(dst_path), fake)
24
+ print('Done')
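`generate_text` writes one image per generated line, numbering them with a zero-padded suffix next to the requested output path. A sketch of the naming scheme:

```python
# Reproduce the output naming used by generate_text.
from pathlib import Path

output = Path("files/output.png")  # default --output value
for i in range(3):                 # pretend three lines were generated
    dst_path = output.parent / (output.stem + f"_{i:03d}" + output.suffix)
    print(dst_path)                # files/output_000.png, files/output_001.png, files/output_002.png
```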
generate/util.py ADDED
@@ -0,0 +1,15 @@
1
+ import numpy as np
2
+
3
+
4
+ def stack_lines(lines: list, h_gap: int = 6):
5
+ width = max([im.shape[1] for im in lines])
6
+ height = (lines[0].shape[0] + h_gap) * len(lines)
7
+
8
+ result = np.ones((height, width)) * 255
9
+
10
+ y_pos = 0
11
+ for line in lines:
12
+ result[y_pos:y_pos + line.shape[0], 0:line.shape[1]] = line
13
+ y_pos += line.shape[0] + h_gap
14
+
15
+ return result
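A usage sketch for `stack_lines`; it assumes all line images share the height of the first one, as the generated 32-pixel-high lines do.

```python
# Stack three dummy line images of different widths onto one white canvas.
import numpy as np
from generate.util import stack_lines  # the helper defined above

lines = [np.zeros((32, w)) for w in (120, 80, 150)]  # black dummy "lines"
page = stack_lines(lines, h_gap=6)

print(page.shape)  # (3 * (32 + 6), 150) == (114, 150)
```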
generate/writer.py ADDED
@@ -0,0 +1,329 @@
1
+ import json
2
+ import os
3
+ import random
4
+ import shutil
5
+ from collections import defaultdict
6
+ import time
7
+ from datetime import timedelta
8
+ from pathlib import Path
9
+
10
+ import cv2
11
+ import numpy as np
12
+ import torch
13
+
14
+ from data.dataset import FolderDataset
15
+ from models.model import VATr
16
+ from util.loading import load_checkpoint, load_generator
17
+ from util.misc import FakeArgs
18
+ from util.text import TextGenerator
19
+ from util.vision import detect_text_bounds
20
+
21
+
22
+ def get_long_tail_chars():
23
+ with open(f"files/longtail.txt", 'r') as f:
24
+ chars = [c.rstrip() for c in f]
25
+
26
+ chars.remove('')
27
+
28
+ return chars
29
+
30
+
31
+ class Writer:
32
+ def __init__(self, checkpoint_path, args, only_generator: bool = False):
33
+ self.model = VATr(args)
34
+ checkpoint = torch.load(checkpoint_path, map_location=args.device)
35
+ load_checkpoint(self.model, checkpoint) if not only_generator else load_generator(self.model, checkpoint)
36
+ self.model.eval()
37
+ self.style_dataset = None
38
+
39
+ def set_style_folder(self, style_folder, num_examples=15):
40
+ word_lengths = None
41
+ if os.path.exists(os.path.join(style_folder, "word_lengths.txt")):
42
+ word_lengths = {}
43
+ with open(os.path.join(style_folder, "word_lengths.txt"), 'r') as f:
44
+ for line in f:
45
+ word, length = line.rstrip().split(",")
46
+ word_lengths[word] = int(length)
47
+
48
+ self.style_dataset = FolderDataset(style_folder, num_examples=num_examples, word_lengths=word_lengths)
49
+
50
+ @torch.no_grad()
51
+ def generate(self, texts, align_words: bool = False, at_once: bool = False):
52
+ if isinstance(texts, str):
53
+ texts = [texts]
54
+ if self.style_dataset is None:
55
+ raise Exception('Style is not set')
56
+
57
+ fakes = []
58
+ for i, text in enumerate(texts, 1):
59
+ print(f'[{i}/{len(texts)}] Generating for text: {text}')
60
+ style = self.style_dataset.sample_style()
61
+ style_images = style['simg'].unsqueeze(0).to(self.model.args.device)
62
+
63
+ fake = self.create_fake_sentence(style_images, text, align_words, at_once)
64
+
65
+ fakes.append(fake)
66
+ return fakes
67
+
68
+ @torch.no_grad()
69
+ def create_fake_sentence(self, style_images, text, align_words=False, at_once=False):
70
+ text = "".join([c for c in text if c in self.model.args.alphabet])
71
+
72
+ text = text.split() if not at_once else [text]
73
+ gap = np.ones((32, 16))
74
+
75
+ text_encode, len_text, encode_pos = self.model.netconverter.encode(text)
76
+ text_encode = text_encode.to(self.model.args.device).unsqueeze(0)
77
+
78
+ fake = self.model._generate_fakes(style_images, text_encode, len_text)
79
+ if not at_once:
80
+ if align_words:
81
+ fake = self.stitch_words(fake, show_lines=False)
82
+ else:
83
+ fake = np.concatenate(sum([[img, gap] for img in fake], []), axis=1)[:, :-16]
84
+ else:
85
+ fake = fake[0]
86
+ fake = (fake * 255).astype(np.uint8)
87
+
88
+ return fake
89
+
90
+ @torch.no_grad()
91
+ def generate_authors(self, text, dataset, align_words: bool = False, at_once: bool = False):
92
+ fakes = []
93
+ author_ids = []
94
+ style = []
95
+
96
+ for item in dataset:
97
+ print(f"Generating author {item['wcl']}")
98
+ style_images = item['simg'].to(self.model.args.device).unsqueeze(0)
99
+
100
+ generated_lines = [self.create_fake_sentence(style_images, line, align_words, at_once) for line in text]
101
+
102
+ fakes.append(generated_lines)
103
+ author_ids.append(item['author_id'])
104
+ style.append((((item['simg'].numpy() + 1.0) / 2.0) * 255).astype(np.uint8))
105
+
106
+ return fakes, author_ids, style
107
+
108
+ @torch.no_grad()
109
+ def generate_characters(self, dataset, characters: str):
110
+ """
111
+ Generate each of the given characters for each of the authors in the dataset.
112
+ """
113
+ fakes = []
114
+
115
+ text_encode, len_text, encode_pos = self.model.netconverter.encode([c for c in characters])
116
+ text_encode = text_encode.to(self.model.args.device).unsqueeze(0)
117
+
118
+ for item in dataset:
119
+ print(f"Generating author {item['wcl']}")
120
+ style_images = item['simg'].to(self.model.args.device).unsqueeze(0)
121
+ fake = self.model.netG.evaluate(style_images, text_encode)
122
+
123
+ fakes.append(fake)
124
+
125
+ return fakes
126
+
127
+ @torch.no_grad()
128
+ def generate_batch(self, style_imgs, text):
129
+ """
130
+ Given a batch of style images and text, generate images using the model
131
+ """
132
+ device = self.model.args.device
133
+ text_encode, _, _ = self.model.netconverter.encode(text)
134
+ fakes, _ = self.model.netG(style_imgs.to(device), text_encode.to(device))
135
+ return fakes
136
+
137
+ @torch.no_grad()
138
+ def generate_ocr(self, dataset, number: int, output_folder: str = 'saved_images/ocr', interpolate_style: bool = False, text_generator: TextGenerator = None, long_tail: bool = False):
139
+ def create_and_write(style, text, interpolated=False):
140
+ nonlocal image_counter, annotations
141
+
142
+ text_encode, len_text, encode_pos = self.model.netconverter.encode([text])
143
+ text_encode = text_encode.to(self.model.args.device)
144
+
145
+ fake = self.model.netG.generate(style, text_encode)
146
+
147
+ fake = (fake + 1) / 2
148
+ fake = fake.cpu().numpy()
149
+ fake = np.squeeze((fake * 255).astype(np.uint8))
150
+
151
+ image_filename = f"{image_counter}.png" if not interpolated else f"{image_counter}_i.png"
152
+
153
+ cv2.imwrite(os.path.join(output_folder, "generated", image_filename), fake)
154
+
155
+ annotations.append((image_filename, text))
156
+
157
+ image_counter += 1
158
+
159
+ image_counter = 0
160
+ annotations = []
161
+ previous_style = None
162
+ long_tail_chars = get_long_tail_chars()
163
+
164
+ os.mkdir(os.path.join(output_folder, "generated"))
165
+ if text_generator is None:
166
+ os.mkdir(os.path.join(output_folder, "reference"))
167
+
168
+ while image_counter < number:
169
+ author_index = random.randint(0, len(dataset) - 1)
170
+ item = dataset[author_index]
171
+
172
+ style_images = item['simg'].to(self.model.args.device).unsqueeze(0)
173
+ style = self.model.netG.compute_style(style_images)
174
+
175
+ if interpolate_style and previous_style is not None:
176
+ factor = float(np.clip(random.gauss(0.5, 0.15), 0.0, 1.0))
177
+ intermediate_style = torch.lerp(previous_style, style, factor)
178
+ text = text_generator.generate()
179
+
180
+ create_and_write(intermediate_style, text, interpolated=True)
181
+
182
+ if text_generator is not None:
183
+ text = text_generator.generate()
184
+ else:
185
+ text = str(item['label'].decode())
186
+
187
+ if long_tail and not any(c in long_tail_chars for c in text):
188
+ continue
189
+
190
+ fake = (item['img'] + 1) / 2
191
+ fake = fake.cpu().numpy()
192
+ fake = np.squeeze((fake * 255).astype(np.uint8))
193
+
194
+ image_filename = f"{image_counter}.png"
195
+
196
+ cv2.imwrite(os.path.join(output_folder, "reference", image_filename), fake)
197
+
198
+ create_and_write(style, text)
199
+
200
+ previous_style = style
201
+
202
+ if text_generator is None:
203
+ with open(os.path.join(output_folder, "reference", "labels.csv"), 'w') as fr:
204
+ fr.write(f"filename,words\n")
205
+ for annotation in annotations:
206
+ fr.write(f"{annotation[0]},{annotation[1]}\n")
207
+
208
+ with open(os.path.join(output_folder, "generated", "labels.csv"), 'w') as fg:
209
+ fg.write(f"filename,words\n")
210
+ for annotation in annotations:
211
+ fg.write(f"{annotation[0]},{annotation[1]}\n")
212
+
213
+
214
+ @staticmethod
215
+ def stitch_words(words: list, show_lines: bool = False, scale_words: bool = False):
216
+ gap_width = 16
217
+
218
+ bottom_lines = []
219
+ top_lines = []
220
+ for i in range(len(words)):
221
+ b, t = detect_text_bounds(words[i])
222
+ bottom_lines.append(b)
223
+ top_lines.append(t)
224
+ if show_lines:
225
+ words[i] = cv2.line(words[i], (0, b), (words[i].shape[1], b), (0, 0, 1.0))
226
+ words[i] = cv2.line(words[i], (0, t), (words[i].shape[1], t), (1.0, 0, 0))
227
+
228
+ bottom_lines = np.array(bottom_lines, dtype=float)
229
+
230
+ if scale_words:
231
+ top_lines = np.array(top_lines, dtype=float)
232
+ gaps = bottom_lines - top_lines
233
+ target_gap = np.mean(gaps)
234
+ scales = target_gap / gaps
235
+
236
+ bottom_lines *= scales
237
+ top_lines *= scales
238
+ words = [cv2.resize(word, None, fx=scale, fy=scale) for word, scale in zip(words, scales)]
239
+
240
+ highest = np.max(bottom_lines)
241
+ offsets = highest - bottom_lines
242
+ height = np.max(offsets + [word.shape[0] for word in words])
243
+
244
+ result = np.ones((int(height), gap_width * len(words) + sum([w.shape[1] for w in words])))
245
+
246
+ x_pos = 0
247
+ for bottom_line, word in zip(bottom_lines, words):
248
+ offset = int(highest - bottom_line)
249
+
250
+ result[offset:offset + word.shape[0], x_pos:x_pos+word.shape[1]] = word
251
+
252
+ x_pos += word.shape[1] + gap_width
253
+
254
+ return result
255
+
256
+ @torch.no_grad()
257
+ def generate_fid(self, path, loader, model_tag, split='train', fake_only=False, long_tail_only=False):
258
+ if not isinstance(path, Path):
259
+ path = Path(path)
260
+
261
+ path.mkdir(exist_ok=True, parents=True)
262
+
263
+ appendix = f"{split}" if not long_tail_only else f"{split}_lt"
264
+
265
+ real_base = path / f'real_{appendix}'
266
+ fake_base = path / model_tag / f'fake_{appendix}'
267
+
268
+ if real_base.exists() and not fake_only:
269
+ shutil.rmtree(real_base)
270
+
271
+ if fake_base.exists():
272
+ shutil.rmtree(fake_base)
273
+
274
+ real_base.mkdir(exist_ok=True)
275
+ fake_base.mkdir(exist_ok=True, parents=True)
276
+
277
+ print('Saving images...')
278
+
279
+ print(' Saving images to {}'.format(str(real_base)))
280
+ print(' Saving images to {}'.format(str(fake_base)))
281
+
282
+ long_tail_chars = get_long_tail_chars()
283
+ counter = 0
284
+ ann = defaultdict(lambda: {})
285
+ start_time = time.time()
286
+ for step, data in enumerate(loader):
287
+ style_images = data['simg'].to(self.model.args.device)
288
+
289
+ texts = [l.decode('utf-8') for l in data['label']]
290
+ texts = [t.encode('utf-8') for t in texts]
291
+ eval_text_encode, eval_len_text, _ = self.model.netconverter.encode(texts)
292
+ eval_text_encode = eval_text_encode.to(self.model.args.device).unsqueeze(1)
293
+
294
+ vis_style = np.vstack(style_images[0].detach().cpu().numpy())
295
+ vis_style = ((vis_style + 1) / 2) * 255
296
+
297
+ fakes = self.model.netG.evaluate(style_images, eval_text_encode)
298
+ fake_images = torch.cat(fakes, 1).detach().cpu().numpy()
299
+ real_images = data['img'].detach().cpu().numpy()
300
+ writer_ids = data['wcl'].int().tolist()
301
+
302
+ for i, (fake, real, wid, lb, img_id) in enumerate(zip(fake_images, real_images, writer_ids, data['label'], data['idx'])):
303
+ lb = lb.decode()
304
+ ann[f"{wid:03d}"][f'{img_id:05d}'] = lb
305
+ img_id = f'{img_id:05d}.png'
306
+
307
+ is_long_tail = any(c in long_tail_chars for c in lb)
308
+
309
+ if long_tail_only and not is_long_tail:
310
+ continue
311
+
312
+ fake_img_path = fake_base / f"{wid:03d}" / img_id
313
+ fake_img_path.parent.mkdir(exist_ok=True, parents=True)
314
+ cv2.imwrite(str(fake_img_path), 255 * ((fake.squeeze() + 1) / 2))
315
+
316
+ if not fake_only:
317
+ real_img_path = real_base / f"{wid:03d}" / img_id
318
+ real_img_path.parent.mkdir(exist_ok=True, parents=True)
319
+ cv2.imwrite(str(real_img_path), 255 * ((real.squeeze() + 1) / 2))
320
+
321
+ counter += 1
322
+
323
+ eta = (time.time() - start_time) / (step + 1) * (len(loader) - step - 1)
324
+ eta = str(timedelta(seconds=eta))
325
+ if step % 100 == 0:
326
+ print(f'[{(step + 1) / len(loader) * 100:.02f}%][{counter:05d}] ETA {eta}')
327
+
328
+ with open(path / 'ann.json', 'w') as f:
329
+ json.dump(ann, f)
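`generate_fid` also dumps an `ann.json` mapping zero-padded writer ids to `{image id: transcription}`. A sketch for reading it back (the folder is an example output location):

```python
# Inspect the annotation file written at the end of Writer.generate_fid.
import json
from pathlib import Path

fid_root = Path("saved_images/fid/IAM32")  # example output folder
with open(fid_root / "ann.json") as f:
    ann = json.load(f)

n_images = sum(len(images) for images in ann.values())
print(f"{len(ann)} writers, {n_images} annotated images")
```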
generation_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.46.2"
4
+ }
hwt/config.json ADDED
@@ -0,0 +1,46 @@
1
+ {
2
+ "add_noise": false,
3
+ "alphabet": "Only thewigsofrcvdampbkuq.A-210xT5'MDL,RYHJ\"ISPWENj&BC93VGFKz();#:!7U64Q8?+*ZX/%",
4
+ "architectures": [
5
+ "VATrPP"
6
+ ],
7
+ "augment_ocr": false,
8
+ "batch_size": 8,
9
+ "corpus": "standard",
10
+ "d_crop_size": null,
11
+ "d_lr": 1e-05,
12
+ "dataset": "IAM",
13
+ "device": "cuda",
14
+ "english_words_path": "files/english_words.txt",
15
+ "epochs": 100000,
16
+ "feat_model_path": "files/resnet_18_pretrained.pth",
17
+ "file_suffix": null,
18
+ "g_lr": 5e-05,
19
+ "img_height": 32,
20
+ "is_cycle": false,
21
+ "label_encoder": "default",
22
+ "model_type": "emuru",
23
+ "no_ocr_loss": false,
24
+ "no_writer_loss": false,
25
+ "num_examples": 15,
26
+ "num_words": 3,
27
+ "num_workers": 0,
28
+ "num_writers": 339,
29
+ "ocr_lr": 5e-05,
30
+ "query_input": "unifont",
31
+ "resolution": 16,
32
+ "save_model": 5,
33
+ "save_model_history": 500,
34
+ "save_model_path": "saved_models",
35
+ "seed": 742,
36
+ "special_alphabet": "\u0391\u03b1\u0392\u03b2\u0393\u03b3\u0394\u03b4\u0395\u03b5\u0396\u03b6\u0397\u03b7\u0398\u03b8\u0399\u03b9\u039a\u03ba\u039b\u03bb\u039c\u03bc\u039d\u03bd\u039e\u03be\u039f\u03bf\u03a0\u03c0\u03a1\u03c1\u03a3\u03c3\u03c2\u03a4\u03c4\u03a5\u03c5\u03a6\u03c6\u03a7\u03c7\u03a8\u03c8\u03a9\u03c9",
37
+ "tag": "debug",
38
+ "text_aug_type": "proportional",
39
+ "text_augment_strength": 0.0,
40
+ "torch_dtype": "float32",
41
+ "transformers_version": "4.46.2",
42
+ "vocab_size": 80,
43
+ "w_lr": 5e-05,
44
+ "wandb": false,
45
+ "writer_loss_weight": 1.0
46
+ }
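A small sketch for loading `hwt/config.json` and inspecting the fields the model code relies on; `vocab_size` is expected to correspond to the length of `alphabet`.

```python
# Load the exported configuration and print a few key fields.
import json

with open("hwt/config.json") as f:
    cfg = json.load(f)

print(cfg["dataset"], cfg["num_writers"], cfg["img_height"])
print(cfg["vocab_size"], len(cfg["alphabet"]))  # expected to match
```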
hwt/generation_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.46.2"
4
+ }
hwt/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c9bd990cdfd3a2a1683af05705c1f9a17b7f58b580a33853b0d0af7c57f7f2e
3
+ size 560965208
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b1e4b7cae23652acd5c559117d06ef42fdd5317da2a5e0bc94ea44d8c0eb1ff
3
+ size 560965208
modeling_vatrpp.py ADDED
@@ -0,0 +1,338 @@
1
+ from transformers import PreTrainedModel
2
+ from .configuration_vatrpp import VATrPPConfig
3
+ import json
4
+ import os
5
+ import random
6
+ import shutil
7
+ from collections import defaultdict
8
+ import time
9
+ from datetime import timedelta
10
+ from pathlib import Path
11
+
12
+ import cv2
13
+ import numpy as np
14
+ import torch
15
+
16
+ from data.dataset import FolderDataset
17
+ from models.model import VATr
18
+ from util.loading import load_checkpoint, load_generator
19
+ from util.misc import FakeArgs
20
+ from util.text import TextGenerator
21
+ from util.vision import detect_text_bounds
22
+ from torchvision.transforms.functional import to_pil_image
23
+
24
+
25
+ def get_long_tail_chars():
26
+ with open(f"files/longtail.txt", 'r') as f:
27
+ chars = [c.rstrip() for c in f]
28
+
29
+ chars.remove('')
30
+
31
+ return chars
32
+
33
+ class VATrPP(PreTrainedModel):
34
+ config_class = VATrPPConfig
35
+
36
+ def __init__(self, config: VATrPPConfig) -> None:
37
+ super().__init__(config)
38
+ self.model = VATr(config)
39
+ self.model.eval()
40
+
41
+ def set_style_folder(self, style_folder, num_examples=15):
42
+ word_lengths = None
43
+ if os.path.exists(os.path.join(style_folder, "word_lengths.txt")):
44
+ word_lengths = {}
45
+ with open(os.path.join(style_folder, "word_lengths.txt"), 'r') as f:
46
+ for line in f:
47
+ word, length = line.rstrip().split(",")
48
+ word_lengths[word] = int(length)
49
+
50
+ self.style_dataset = FolderDataset(style_folder, num_examples=num_examples, word_lengths=word_lengths)
51
+
52
+ @torch.no_grad()
53
+ def generate(self, gen_text, style_imgs, align_words: bool = False, at_once: bool = False):
54
+ style_images = style_imgs.unsqueeze(0).to(self.model.args.device)
55
+
56
+ fake = self.create_fake_sentence(style_images, gen_text, align_words, at_once)
57
+ return to_pil_image(fake)
58
+
59
+ # @torch.no_grad()
60
+ # def generate(self, texts, align_words: bool = False, at_once: bool = False):
61
+ # if isinstance(texts, str):
62
+ # texts = [texts]
63
+ # if self.style_dataset is None:
64
+ # raise Exception('Style is not set')
65
+
66
+ # fakes = []
67
+ # for i, text in enumerate(texts, 1):
68
+ # print(f'[{i}/{len(texts)}] Generating for text: {text}')
69
+ # style = self.style_dataset.sample_style()
70
+ # style_images = style['simg'].unsqueeze(0).to(self.model.args.device)
71
+
72
+ # fake = self.create_fake_sentence(style_images, text, align_words, at_once)
73
+
74
+ # fakes.append(fake)
75
+ # return fakes
76
+
77
+ @torch.no_grad()
78
+ def create_fake_sentence(self, style_images, text, align_words=False, at_once=False):
79
+ text = "".join([c for c in text if c in self.model.args.alphabet])
80
+
81
+ text = text.split() if not at_once else [text]
82
+ gap = np.ones((32, 16))
83
+
84
+ text_encode, len_text, encode_pos = self.model.netconverter.encode(text)
85
+ text_encode = text_encode.to(self.model.args.device).unsqueeze(0)
86
+
87
+ fake = self.model._generate_fakes(style_images, text_encode, len_text)
88
+ if not at_once:
89
+ if align_words:
90
+ fake = self.stitch_words(fake, show_lines=False)
91
+ else:
92
+ fake = np.concatenate(sum([[img, gap] for img in fake], []), axis=1)[:, :-16]
93
+ else:
94
+ fake = fake[0]
95
+ fake = (fake * 255).astype(np.uint8)
96
+
97
+ return fake
98
+
99
+ @torch.no_grad()
100
+ def generate_authors(self, text, dataset, align_words: bool = False, at_once: bool = False):
101
+ fakes = []
102
+ author_ids = []
103
+ style = []
104
+
105
+ for item in dataset:
106
+ print(f"Generating author {item['wcl']}")
107
+ style_images = item['simg'].to(self.model.args.device).unsqueeze(0)
108
+
109
+ generated_lines = [self.create_fake_sentence(style_images, line, align_words, at_once) for line in text]
110
+
111
+ fakes.append(generated_lines)
112
+ author_ids.append(item['author_id'])
113
+ style.append((((item['simg'].numpy() + 1.0) / 2.0) * 255).astype(np.uint8))
114
+
115
+ return fakes, author_ids, style
116
+
117
+ @torch.no_grad()
118
+ def generate_characters(self, dataset, characters: str):
119
+ """
120
+ Generate each of the given characters for each of the authors in the dataset.
121
+ """
122
+ fakes = []
123
+
124
+ text_encode, len_text, encode_pos = self.model.netconverter.encode([c for c in characters])
125
+ text_encode = text_encode.to(self.model.args.device).unsqueeze(0)
126
+
127
+ for item in dataset:
128
+ print(f"Generating author {item['wcl']}")
129
+ style_images = item['simg'].to(self.model.args.device).unsqueeze(0)
130
+ fake = self.model.netG.evaluate(style_images, text_encode)
131
+
132
+ fakes.append(fake)
133
+
134
+ return fakes
135
+
136
+ @torch.no_grad()
137
+ def generate_batch(self, style_imgs, text):
138
+ """
139
+ Given a batch of style images and text, generate images using the model
140
+ """
141
+ device = self.model.args.device
142
+ text_encode, _, _ = self.model.netconverter.encode(text)
143
+ fakes, _ = self.model.netG(style_imgs.to(device), text_encode.to(device))
144
+ return fakes
145
+
146
+ @torch.no_grad()
147
+ def generate_ocr(self, dataset, number: int, output_folder: str = 'saved_images/ocr', interpolate_style: bool = False, text_generator: TextGenerator = None, long_tail: bool = False):
148
+ def create_and_write(style, text, interpolated=False):
149
+ nonlocal image_counter, annotations
150
+
151
+ text_encode, len_text, encode_pos = self.model.netconverter.encode([text])
152
+ text_encode = text_encode.to(self.model.args.device)
153
+
154
+ fake = self.model.netG.generate(style, text_encode)
155
+
156
+ fake = (fake + 1) / 2
157
+ fake = fake.cpu().numpy()
158
+ fake = np.squeeze((fake * 255).astype(np.uint8))
159
+
160
+ image_filename = f"{image_counter}.png" if not interpolated else f"{image_counter}_i.png"
161
+
162
+ cv2.imwrite(os.path.join(output_folder, "generated", image_filename), fake)
163
+
164
+ annotations.append((image_filename, text))
165
+
166
+ image_counter += 1
167
+
168
+ image_counter = 0
169
+ annotations = []
170
+ previous_style = None
171
+ long_tail_chars = get_long_tail_chars()
172
+
173
+ os.mkdir(os.path.join(output_folder, "generated"))
174
+ if text_generator is None:
175
+ os.mkdir(os.path.join(output_folder, "reference"))
176
+
177
+ while image_counter < number:
178
+ author_index = random.randint(0, len(dataset) - 1)
179
+ item = dataset[author_index]
180
+
181
+ style_images = item['simg'].to(self.model.args.device).unsqueeze(0)
182
+ style = self.model.netG.compute_style(style_images)
183
+
184
+ if interpolate_style and previous_style is not None:
185
+ factor = float(np.clip(random.gauss(0.5, 0.15), 0.0, 1.0))
186
+ intermediate_style = torch.lerp(previous_style, style, factor)
187
+ text = text_generator.generate()
188
+
189
+ create_and_write(intermediate_style, text, interpolated=True)
190
+
191
+ if text_generator is not None:
192
+ text = text_generator.generate()
193
+ else:
194
+ text = str(item['label'].decode())
195
+
196
+ if long_tail and not any(c in long_tail_chars for c in text):
197
+ continue
198
+
199
+ fake = (item['img'] + 1) / 2
200
+ fake = fake.cpu().numpy()
201
+ fake = np.squeeze((fake * 255).astype(np.uint8))
202
+
203
+ image_filename = f"{image_counter}.png"
204
+
205
+ cv2.imwrite(os.path.join(output_folder, "reference", image_filename), fake)
206
+
207
+ create_and_write(style, text)
208
+
209
+ previous_style = style
210
+
211
+ if text_generator is None:
212
+ with open(os.path.join(output_folder, "reference", "labels.csv"), 'w') as fr:
213
+ fr.write(f"filename,words\n")
214
+ for annotation in annotations:
215
+ fr.write(f"{annotation[0]},{annotation[1]}\n")
216
+
217
+ with open(os.path.join(output_folder, "generated", "labels.csv"), 'w') as fg:
218
+ fg.write(f"filename,words\n")
219
+ for annotation in annotations:
220
+ fg.write(f"{annotation[0]},{annotation[1]}\n")
221
+
222
+
223
+ @staticmethod
224
+ def stitch_words(words: list, show_lines: bool = False, scale_words: bool = False):
225
+ gap_width = 16
226
+
227
+ bottom_lines = []
228
+ top_lines = []
229
+ for i in range(len(words)):
230
+ b, t = detect_text_bounds(words[i])
231
+ bottom_lines.append(b)
232
+ top_lines.append(t)
233
+ if show_lines:
234
+ words[i] = cv2.line(words[i], (0, b), (words[i].shape[1], b), (0, 0, 1.0))
235
+ words[i] = cv2.line(words[i], (0, t), (words[i].shape[1], t), (1.0, 0, 0))
236
+
237
+ bottom_lines = np.array(bottom_lines, dtype=float)
238
+
239
+ if scale_words:
240
+ top_lines = np.array(top_lines, dtype=float)
241
+ gaps = bottom_lines - top_lines
242
+ target_gap = np.mean(gaps)
243
+ scales = target_gap / gaps
244
+
245
+ bottom_lines *= scales
246
+ top_lines *= scales
247
+ words = [cv2.resize(word, None, fx=scale, fy=scale) for word, scale in zip(words, scales)]
248
+
249
+ highest = np.max(bottom_lines)
250
+ offsets = highest - bottom_lines
251
+ height = np.max(offsets + [word.shape[0] for word in words])
252
+
253
+ result = np.ones((int(height), gap_width * len(words) + sum([w.shape[1] for w in words])))
254
+
255
+ x_pos = 0
256
+ for bottom_line, word in zip(bottom_lines, words):
257
+ offset = int(highest - bottom_line)
258
+
259
+ result[offset:offset + word.shape[0], x_pos:x_pos+word.shape[1]] = word
260
+
261
+ x_pos += word.shape[1] + gap_width
262
+
263
+ return result
264
+
265
+ @torch.no_grad()
266
+ def generate_fid(self, path, loader, model_tag, split='train', fake_only=False, long_tail_only=False):
267
+ if not isinstance(path, Path):
268
+ path = Path(path)
269
+
270
+ path.mkdir(exist_ok=True, parents=True)
271
+
272
+ appendix = f"{split}" if not long_tail_only else f"{split}_lt"
273
+
274
+ real_base = path / f'real_{appendix}'
275
+ fake_base = path / model_tag / f'fake_{appendix}'
276
+
277
+ if real_base.exists() and not fake_only:
278
+ shutil.rmtree(real_base)
279
+
280
+ if fake_base.exists():
281
+ shutil.rmtree(fake_base)
282
+
283
+ real_base.mkdir(exist_ok=True)
284
+ fake_base.mkdir(exist_ok=True, parents=True)
285
+
286
+ print('Saving images...')
287
+
288
+ print(' Saving images to {}'.format(str(real_base)))
289
+ print(' Saving images to {}'.format(str(fake_base)))
290
+
291
+ long_tail_chars = get_long_tail_chars()
292
+ counter = 0
293
+ ann = defaultdict(lambda: {})
294
+ start_time = time.time()
295
+ for step, data in enumerate(loader):
296
+ style_images = data['simg'].to(self.model.args.device)
297
+
298
+ texts = [l.decode('utf-8') for l in data['label']]
299
+ texts = [t.encode('utf-8') for t in texts]
300
+ eval_text_encode, eval_len_text, _ = self.model.netconverter.encode(texts)
301
+ eval_text_encode = eval_text_encode.to(self.model.args.device).unsqueeze(1)
302
+
303
+ vis_style = np.vstack(style_images[0].detach().cpu().numpy())
304
+ vis_style = ((vis_style + 1) / 2) * 255
305
+
306
+ fakes = self.model.netG.evaluate(style_images, eval_text_encode)
307
+ fake_images = torch.cat(fakes, 1).detach().cpu().numpy()
308
+ real_images = data['img'].detach().cpu().numpy()
309
+ writer_ids = data['wcl'].int().tolist()
310
+
311
+ for i, (fake, real, wid, lb, img_id) in enumerate(zip(fake_images, real_images, writer_ids, data['label'], data['idx'])):
312
+ lb = lb.decode()
313
+ ann[f"{wid:03d}"][f'{img_id:05d}'] = lb
314
+ img_id = f'{img_id:05d}.png'
315
+
316
+ is_long_tail = any(c in long_tail_chars for c in lb)
317
+
318
+ if long_tail_only and not is_long_tail:
319
+ continue
320
+
321
+ fake_img_path = fake_base / f"{wid:03d}" / img_id
322
+ fake_img_path.parent.mkdir(exist_ok=True, parents=True)
323
+ cv2.imwrite(str(fake_img_path), 255 * ((fake.squeeze() + 1) / 2))
324
+
325
+ if not fake_only:
326
+ real_img_path = real_base / f"{wid:03d}" / img_id
327
+ real_img_path.parent.mkdir(exist_ok=True, parents=True)
328
+ cv2.imwrite(str(real_img_path), 255 * ((real.squeeze() + 1) / 2))
329
+
330
+ counter += 1
331
+
332
+ eta = (time.time() - start_time) / (step + 1) * (len(loader) - step - 1)
333
+ eta = str(timedelta(seconds=eta))
334
+ if step % 100 == 0:
335
+ print(f'[{(step + 1) / len(loader) * 100:.02f}%][{counter:05d}] ETA {eta}')
336
+
337
+ with open(path / 'ann.json', 'w') as f:
338
+ json.dump(ann, f)
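Because `VATrPP` subclasses `PreTrainedModel` with a custom config class, a published checkpoint with this layout (config.json plus model.safetensors and the `modeling_vatrpp.py`/`configuration_vatrpp.py` modules) could be loaded through the `transformers` remote-code mechanism. The snippet below is only a sketch: the repo id is a placeholder, and it assumes the Hub repo registers the class via `auto_map` so that `AutoModel` can resolve it.

```python
# Hedged loading sketch; "vittoriopippi/vatrpp" is a placeholder repo id.
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "vittoriopippi/vatrpp",   # hypothetical Hub repo containing this code and weights
    trust_remote_code=True,   # required because VATrPP is defined in modeling_vatrpp.py
)
model.eval()

# model.generate(text, style_imgs) then returns a PIL image of the synthesized
# sentence, where style_imgs is a tensor of preprocessed writer samples.
```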
models/BigGAN_layers.py ADDED
@@ -0,0 +1,469 @@
1
+ ''' Layers
2
+ This file contains various layers for the BigGAN models.
3
+ '''
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn import init
8
+ import torch.optim as optim
9
+ import torch.nn.functional as F
10
+ from torch.nn import Parameter as P
11
+
12
+ from .sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d
13
+
14
+ # Projection of x onto y
15
+ def proj(x, y):
16
+ return torch.mm(y, x.t()) * y / torch.mm(y, y.t())
17
+
18
+
19
+ # Orthogonalize x wrt list of vectors ys
20
+ def gram_schmidt(x, ys):
21
+ for y in ys:
22
+ x = x - proj(x, y)
23
+ return x
24
+
25
+
26
+ # Apply num_itrs steps of the power method to estimate top N singular values.
27
+ def power_iteration(W, u_, update=True, eps=1e-12):
28
+ # Lists holding singular vectors and values
29
+ us, vs, svs = [], [], []
30
+ for i, u in enumerate(u_):
31
+ # Run one step of the power iteration
32
+ with torch.no_grad():
33
+ v = torch.matmul(u, W)
34
+ # Run Gram-Schmidt to subtract components of all other singular vectors
35
+ v = F.normalize(gram_schmidt(v, vs), eps=eps)
36
+ # Add to the list
37
+ vs += [v]
38
+ # Update the other singular vector
39
+ u = torch.matmul(v, W.t())
40
+ # Run Gram-Schmidt to subtract components of all other singular vectors
41
+ u = F.normalize(gram_schmidt(u, us), eps=eps)
42
+ # Add to the list
43
+ us += [u]
44
+ if update:
45
+ u_[i][:] = u
46
+ # Compute this singular value and add it to the list
47
+ svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
48
+ # svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
49
+ return svs, us, vs
50
+
51
+
52
+ # Convenience passthrough function
53
+ class identity(nn.Module):
54
+ def forward(self, input):
55
+ return input
56
+
57
+
58
+ # Spectral normalization base class
59
+ class SN(object):
60
+ def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
61
+ # Number of power iterations per step
62
+ self.num_itrs = num_itrs
63
+ # Number of singular values
64
+ self.num_svs = num_svs
65
+ # Transposed?
66
+ self.transpose = transpose
67
+ # Epsilon value for avoiding divide-by-0
68
+ self.eps = eps
69
+ # Register a singular vector for each sv
70
+ for i in range(self.num_svs):
71
+ self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
72
+ self.register_buffer('sv%d' % i, torch.ones(1))
73
+
74
+ # Singular vectors (u side)
75
+ @property
76
+ def u(self):
77
+ return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]
78
+
79
+ # Singular values;
80
+ # note that these buffers are just for logging and are not used in training.
81
+ @property
82
+ def sv(self):
83
+ return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]
84
+
85
+ # Compute the spectrally-normalized weight
86
+ def W_(self):
87
+ W_mat = self.weight.view(self.weight.size(0), -1)
88
+ if self.transpose:
89
+ W_mat = W_mat.t()
90
+ # Apply num_itrs power iterations
91
+ for _ in range(self.num_itrs):
92
+ svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
93
+ # Update the svs
94
+ if self.training:
95
+ with torch.no_grad(): # Make sure to do this in a no_grad() context or you'll get memory leaks!
96
+ for i, sv in enumerate(svs):
97
+ self.sv[i][:] = sv
98
+ return self.weight / svs[0]
99
+
100
+
101
+ # 2D Conv layer with spectral norm
102
+ class SNConv2d(nn.Conv2d, SN):
103
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
104
+ padding=0, dilation=1, groups=1, bias=True,
105
+ num_svs=1, num_itrs=1, eps=1e-12):
106
+ nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride,
107
+ padding, dilation, groups, bias)
108
+ SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)
109
+
110
+ def forward(self, x):
111
+ return F.conv2d(x, self.W_(), self.bias, self.stride,
112
+ self.padding, self.dilation, self.groups)
113
+
114
+
115
+ # Linear layer with spectral norm
116
+ class SNLinear(nn.Linear, SN):
117
+ def __init__(self, in_features, out_features, bias=True,
118
+ num_svs=1, num_itrs=1, eps=1e-12):
119
+ nn.Linear.__init__(self, in_features, out_features, bias)
120
+ SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)
121
+
122
+ def forward(self, x):
123
+ return F.linear(x, self.W_(), self.bias)
124
+
125
+
126
+ # Embedding layer with spectral norm
127
+ # We use num_embeddings as the dim instead of embedding_dim here
128
+ # for convenience sake
129
+ class SNEmbedding(nn.Embedding, SN):
130
+ def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
131
+ max_norm=None, norm_type=2, scale_grad_by_freq=False,
132
+ sparse=False, _weight=None,
133
+ num_svs=1, num_itrs=1, eps=1e-12):
134
+ nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx,
135
+ max_norm, norm_type, scale_grad_by_freq,
136
+ sparse, _weight)
137
+ SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)
138
+
139
+ def forward(self, x):
140
+ return F.embedding(x, self.W_())
141
+
142
+
143
+ # A non-local block as used in SA-GAN
144
+ # Note that the implementation as described in the paper is largely incorrect;
145
+ # refer to the released code for the actual implementation.
146
+ class Attention(nn.Module):
147
+ def __init__(self, ch, which_conv=SNConv2d, name='attention'):
148
+ super(Attention, self).__init__()
149
+ # Channel multiplier
150
+ self.ch = ch
151
+ self.which_conv = which_conv
152
+ self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
153
+ self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
154
+ self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False)
155
+ self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False)
156
+ # Learnable gain parameter
157
+ self.gamma = P(torch.tensor(0.), requires_grad=True)
158
+
159
+ def forward(self, x, y=None):
160
+ # Apply convs
161
+ theta = self.theta(x)
162
+ phi = F.max_pool2d(self.phi(x), [2, 2])
163
+ g = F.max_pool2d(self.g(x), [2, 2])
164
+ # Perform reshapes
165
+ theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
166
+ try:
167
+ phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
168
+ except:
169
+ print(phi.shape)
170
+ g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)
171
+ # Matmul and softmax to get attention maps
172
+ beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
173
+ # Attention map times g path
174
+ o = self.o(torch.bmm(g, beta.transpose(1, 2)).view(-1, self.ch // 2, x.shape[2], x.shape[3]))
175
+ return self.gamma * o + x
176
+
177
+
178
+ # Fused batchnorm op
179
+ def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
180
+ # Apply scale and shift--if gain and bias are provided, fuse them here
181
+ # Prepare scale
182
+ scale = torch.rsqrt(var + eps)
183
+ # If a gain is provided, use it
184
+ if gain is not None:
185
+ scale = scale * gain
186
+ # Prepare shift
187
+ shift = mean * scale
188
+ # If bias is provided, use it
189
+ if bias is not None:
190
+ shift = shift - bias
191
+ return x * scale - shift
192
+ # return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way.
193
+
194
+
195
+ # Manual BN
196
+ # Calculate means and variances using mean-of-squares minus mean-squared
197
+ def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
198
+ # Cast x to float32 if necessary
199
+ float_x = x.float()
200
+ # Calculate expected value of x (m) and expected value of x**2 (m2)
201
+ # Mean of x
202
+ m = torch.mean(float_x, [0, 2, 3], keepdim=True)
203
+ # Mean of x squared
204
+ m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True)
205
+ # Calculate variance as mean of squared minus mean squared.
206
+ var = (m2 - m ** 2)
207
+ # Cast back to float 16 if necessary
208
+ var = var.type(x.type())
209
+ m = m.type(x.type())
210
+ # Return mean and variance for updating stored mean/var if requested
211
+ if return_mean_var:
212
+ return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze()
213
+ else:
214
+ return fused_bn(x, m, var, gain, bias, eps)
215
+
216
+
217
+ # My batchnorm, supports standing stats
218
+ class myBN(nn.Module):
219
+ def __init__(self, num_channels, eps=1e-5, momentum=0.1):
220
+ super(myBN, self).__init__()
221
+ # momentum for updating running stats
222
+ self.momentum = momentum
223
+ # epsilon to avoid dividing by 0
224
+ self.eps = eps
225
+ # Momentum
226
+ self.momentum = momentum
227
+ # Register buffers
228
+ self.register_buffer('stored_mean', torch.zeros(num_channels))
229
+ self.register_buffer('stored_var', torch.ones(num_channels))
230
+ self.register_buffer('accumulation_counter', torch.zeros(1))
231
+ # Accumulate running means and vars
232
+ self.accumulate_standing = False
233
+
234
+ # reset standing stats
235
+ def reset_stats(self):
236
+ self.stored_mean[:] = 0
237
+ self.stored_var[:] = 0
238
+ self.accumulation_counter[:] = 0
239
+
240
+ def forward(self, x, gain, bias):
241
+ if self.training:
242
+ out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
243
+ # If accumulating standing stats, increment them
244
+ if self.accumulate_standing:
245
+ self.stored_mean[:] = self.stored_mean + mean.data
246
+ self.stored_var[:] = self.stored_var + var.data
247
+ self.accumulation_counter += 1.0
248
+ # If not accumulating standing stats, take running averages
249
+ else:
250
+ self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum
251
+ self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum
252
+ return out
253
+ # If not in training mode, use the stored statistics
254
+ else:
255
+ mean = self.stored_mean.view(1, -1, 1, 1)
256
+ var = self.stored_var.view(1, -1, 1, 1)
257
+ # If using standing stats, divide them by the accumulation counter
258
+ if self.accumulate_standing:
259
+ mean = mean / self.accumulation_counter
260
+ var = var / self.accumulation_counter
261
+ return fused_bn(x, mean, var, gain, bias, self.eps)
262
+
263
+
264
+ # Simple function to handle groupnorm norm stylization
265
+ def groupnorm(x, norm_style):
266
+ # If number of channels specified in norm_style:
267
+ if 'ch' in norm_style:
268
+ ch = int(norm_style.split('_')[-1])
269
+ groups = max(int(x.shape[1]) // ch, 1)
270
+ # If number of groups specified in norm style
271
+ elif 'grp' in norm_style:
272
+ groups = int(norm_style.split('_')[-1])
273
+ # If neither, default to groups = 16
274
+ else:
275
+ groups = 16
276
+ return F.group_norm(x, groups)
277
+
278
+
279
+ # Class-conditional bn
280
+ # output size is the number of channels, input size is for the linear layers
281
+ # Andy's Note: this class feels messy but I'm not really sure how to clean it up
282
+ # Suggestions welcome! (By which I mean, refactor this and make a pull request
283
+ # if you want to make this more readable/usable).
284
+ class ccbn(nn.Module):
285
+ def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
286
+ cross_replica=False, mybn=False, norm_style='bn', ):
287
+ super(ccbn, self).__init__()
288
+ self.output_size, self.input_size = output_size, input_size
289
+ # Prepare gain and bias layers
290
+ self.gain = which_linear(input_size, output_size)
291
+ self.bias = which_linear(input_size, output_size)
292
+ # epsilon to avoid dividing by 0
293
+ self.eps = eps
294
+ # Momentum
295
+ self.momentum = momentum
296
+ # Use cross-replica batchnorm?
297
+ self.cross_replica = cross_replica
298
+ # Use my batchnorm?
299
+ self.mybn = mybn
300
+ # Norm style?
301
+ self.norm_style = norm_style
302
+
303
+ if self.cross_replica:
304
+ self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
305
+ elif self.mybn:
306
+ self.bn = myBN(output_size, self.eps, self.momentum)
307
+ elif self.norm_style in ['bn', 'in']:
308
+ self.register_buffer('stored_mean', torch.zeros(output_size))
309
+ self.register_buffer('stored_var', torch.ones(output_size))
310
+
311
+ def forward(self, x, y):
312
+ # Calculate class-conditional gains and biases
313
+ gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
314
+ bias = self.bias(y).view(y.size(0), -1, 1, 1)
315
+ # If using my batchnorm
316
+ if self.mybn or self.cross_replica:
317
+ return self.bn(x, gain=gain, bias=bias)
318
+ # else:
319
+ else:
320
+ if self.norm_style == 'bn':
321
+ out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
322
+ self.training, 0.1, self.eps)
323
+ elif self.norm_style == 'in':
324
+ out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
325
+ self.training, 0.1, self.eps)
326
+ elif self.norm_style == 'gn':
327
+ out = groupnorm(x, self.norm_style)  # fixed typo: the attribute is norm_style
328
+ elif self.norm_style == 'nonorm':
329
+ out = x
330
+ return out * gain + bias
331
+
332
+ def extra_repr(self):
333
+ s = 'out: {output_size}, in: {input_size},'
334
+ s += ' cross_replica={cross_replica}'
335
+ return s.format(**self.__dict__)
336
+
337
+
338
+ # Normal, non-class-conditional BN
339
+ class bn(nn.Module):
340
+ def __init__(self, output_size, eps=1e-5, momentum=0.1,
341
+ cross_replica=False, mybn=False):
342
+ super(bn, self).__init__()
343
+ self.output_size = output_size
344
+ # Prepare gain and bias layers
345
+ self.gain = P(torch.ones(output_size), requires_grad=True)
346
+ self.bias = P(torch.zeros(output_size), requires_grad=True)
347
+ # epsilon to avoid dividing by 0
348
+ self.eps = eps
349
+ # Momentum
350
+ self.momentum = momentum
351
+ # Use cross-replica batchnorm?
352
+ self.cross_replica = cross_replica
353
+ # Use my batchnorm?
354
+ self.mybn = mybn
355
+
356
+ if self.cross_replica:
357
+ self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
358
+ elif mybn:
359
+ self.bn = myBN(output_size, self.eps, self.momentum)
360
+ # Register buffers if neither of the above
361
+ else:
362
+ self.register_buffer('stored_mean', torch.zeros(output_size))
363
+ self.register_buffer('stored_var', torch.ones(output_size))
364
+
365
+ def forward(self, x, y=None):
366
+ if self.cross_replica or self.mybn:
367
+ gain = self.gain.view(1, -1, 1, 1)
368
+ bias = self.bias.view(1, -1, 1, 1)
369
+ return self.bn(x, gain=gain, bias=bias)
370
+ else:
371
+ return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
372
+ self.bias, self.training, self.momentum, self.eps)
373
+
374
+
375
+ # Generator blocks
376
+ # Note that this class assumes the kernel size and padding (and any other
377
+ # settings) have been selected in the main generator module and passed in
378
+ # through the which_conv arg. Similar rules apply with which_bn (the input
379
+ # size [which is actually the number of channels of the conditional info] must
380
+ # be preselected)
381
+ class GBlock(nn.Module):
382
+ def __init__(self, in_channels, out_channels,
383
+ which_conv1=nn.Conv2d, which_conv2=nn.Conv2d, which_bn=bn, activation=None,
384
+ upsample=None):
385
+ super(GBlock, self).__init__()
386
+
387
+ self.in_channels, self.out_channels = in_channels, out_channels
388
+ self.which_conv1, self.which_conv2, self.which_bn = which_conv1, which_conv2, which_bn
389
+ self.activation = activation
390
+ self.upsample = upsample
391
+ # Conv layers
392
+ self.conv1 = self.which_conv1(self.in_channels, self.out_channels)
393
+ self.conv2 = self.which_conv2(self.out_channels, self.out_channels)
394
+ self.learnable_sc = in_channels != out_channels or upsample
395
+ if self.learnable_sc:
396
+ self.conv_sc = self.which_conv1(in_channels, out_channels,
397
+ kernel_size=1, padding=0)
398
+ # Batchnorm layers
399
+ self.bn1 = self.which_bn(in_channels)
400
+ self.bn2 = self.which_bn(out_channels)
401
+ # upsample layers
402
+ self.upsample = upsample
403
+
404
+ def forward(self, x, y):
405
+ h = self.activation(self.bn1(x, y))
406
+ # h = self.activation(x)
407
+ # h=x
408
+ if self.upsample:
409
+ h = self.upsample(h)
410
+ x = self.upsample(x)
411
+ h = self.conv1(h)
412
+ h = self.activation(self.bn2(h, y))
413
+ # h = self.activation(h)
414
+ h = self.conv2(h)
415
+ if self.learnable_sc:
416
+ x = self.conv_sc(x)
417
+ return h + x
418
+
419
+
420
+ # Residual block for the discriminator
421
+ class DBlock(nn.Module):
422
+ def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True,
423
+ preactivation=False, activation=None, downsample=None, ):
424
+ super(DBlock, self).__init__()
425
+ self.in_channels, self.out_channels = in_channels, out_channels
426
+ # If using wide D (as in SA-GAN and BigGAN), change the channel pattern
427
+ self.hidden_channels = self.out_channels if wide else self.in_channels
428
+ self.which_conv = which_conv
429
+ self.preactivation = preactivation
430
+ self.activation = activation
431
+ self.downsample = downsample
432
+
433
+ # Conv layers
434
+ self.conv1 = self.which_conv(self.in_channels, self.hidden_channels)
435
+ self.conv2 = self.which_conv(self.hidden_channels, self.out_channels)
436
+ self.learnable_sc = True if (in_channels != out_channels) or downsample else False
437
+ if self.learnable_sc:
438
+ self.conv_sc = self.which_conv(in_channels, out_channels,
439
+ kernel_size=1, padding=0)
440
+
441
+ def shortcut(self, x):
442
+ if self.preactivation:
443
+ if self.learnable_sc:
444
+ x = self.conv_sc(x)
445
+ if self.downsample:
446
+ x = self.downsample(x)
447
+ else:
448
+ if self.downsample:
449
+ x = self.downsample(x)
450
+ if self.learnable_sc:
451
+ x = self.conv_sc(x)
452
+ return x
453
+
454
+ def forward(self, x):
455
+ if self.preactivation:
456
+ # h = self.activation(x) # NOT TODAY SATAN
457
+ # Andy's note: This line *must* be an out-of-place ReLU or it
458
+ # will negatively affect the shortcut connection.
459
+ h = F.relu(x)
460
+ else:
461
+ h = x
462
+ h = self.conv1(h)
463
+ h = self.conv2(self.activation(h))
464
+ if self.downsample:
465
+ h = self.downsample(h)
466
+
467
+ return h + self.shortcut(x)
468
+
469
+ # dogball
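
For readers skimming the diff: the class-conditional batch norm above (`ccbn`) normalizes features without affine parameters and then applies a gain and bias that are linear functions of the conditioning vector. Below is a minimal, self-contained sketch of that idea; it is not part of the commit and uses illustrative shapes only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyConditionalBN(nn.Module):
    """Minimal sketch of the ccbn idea: parameter-free BN followed by a
    per-channel gain and bias predicted from the conditioning vector y."""
    def __init__(self, num_channels, cond_dim, eps=1e-5):
        super().__init__()
        self.gain = nn.Linear(cond_dim, num_channels)  # predicts per-channel scale
        self.bias = nn.Linear(cond_dim, num_channels)  # predicts per-channel shift
        self.register_buffer('stored_mean', torch.zeros(num_channels))
        self.register_buffer('stored_var', torch.ones(num_channels))
        self.eps = eps

    def forward(self, x, y):
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)
        out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                           self.training, 0.1, self.eps)
        return out * gain + bias

x = torch.randn(4, 64, 8, 8)   # feature map
y = torch.randn(4, 128)        # conditioning vector (e.g. a class/style embedding)
print(TinyConditionalBN(64, 128)(x, y).shape)  # torch.Size([4, 64, 8, 8])
```
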
models/BigGAN_networks.py ADDED
@@ -0,0 +1,379 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ # SPDX-License-Identifier: MIT
3
+ import functools
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import random
11
+
12
+ from util.augmentations import ProgressiveWordCrop, CycleWordCrop, StaticWordCrop, RandomWordCrop
13
+ from . import BigGAN_layers as layers
14
+ from .networks import init_weights
15
+ import torchvision
16
+ # Attention is passed in in the format '32_64' to mean applying an attention
17
+ # block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64.
18
+
19
+ from models.blocks import Conv2dBlock, ResBlocks
20
+
21
+
22
+ # Discriminator architecture, same paradigm as G's above
23
+ def D_arch(ch=64, attention='64', input_nc=3, ksize='333333', dilation='111111'):
24
+ arch = {}
25
+ arch[256] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 8, 16]],
26
+ 'out_channels': [item * ch for item in [1, 2, 4, 8, 8, 16, 16]],
27
+ 'downsample': [True] * 6 + [False],
28
+ 'resolution': [128, 64, 32, 16, 8, 4, 4],
29
+ 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
30
+ for i in range(2, 8)}}
31
+ arch[128] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 16]],
32
+ 'out_channels': [item * ch for item in [1, 2, 4, 8, 16, 16]],
33
+ 'downsample': [True] * 5 + [False],
34
+ 'resolution': [64, 32, 16, 8, 4, 4],
35
+ 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
36
+ for i in range(2, 8)}}
37
+ arch[64] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8]],
38
+ 'out_channels': [item * ch for item in [1, 2, 4, 8, 16]],
39
+ 'downsample': [True] * 4 + [False],
40
+ 'resolution': [32, 16, 8, 4, 4],
41
+ 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
42
+ for i in range(2, 7)}}
43
+ arch[63] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8]],
44
+ 'out_channels': [item * ch for item in [1, 2, 4, 8, 16]],
45
+ 'downsample': [True] * 4 + [False],
46
+ 'resolution': [32, 16, 8, 4, 4],
47
+ 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
48
+ for i in range(2, 7)}}
49
+ arch[32] = {'in_channels': [input_nc] + [item * ch for item in [4, 4, 4]],
50
+ 'out_channels': [item * ch for item in [4, 4, 4, 4]],
51
+ 'downsample': [True, True, False, False],
52
+ 'resolution': [16, 16, 16, 16],
53
+ 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
54
+ for i in range(2, 6)}}
55
+ arch[129] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 8, 16]],
56
+ 'out_channels': [item * ch for item in [1, 2, 4, 8, 8, 16, 16]],
57
+ 'downsample': [True] * 6 + [False],
58
+ 'resolution': [128, 64, 32, 16, 8, 4, 4],
59
+ 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
60
+ for i in range(2, 8)}}
61
+ arch[33] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 16]],
62
+ 'out_channels': [item * ch for item in [1, 2, 4, 8, 16, 16]],
63
+ 'downsample': [True] * 5 + [False],
64
+ 'resolution': [64, 32, 16, 8, 4, 4],
65
+ 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
66
+ for i in range(2, 10)}}
67
+ arch[31] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 16]],
68
+ 'out_channels': [item * ch for item in [1, 2, 4, 8, 16, 16]],
69
+ 'downsample': [True] * 5 + [False],
70
+ 'resolution': [64, 32, 16, 8, 4, 4],
71
+ 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
72
+ for i in range(2, 10)}}
73
+ arch[16] = {'in_channels': [input_nc] + [ch * item for item in [1, 8, 16]],
74
+ 'out_channels': [item * ch for item in [1, 8, 16, 16]],
75
+ 'downsample': [True] * 3 + [False],
76
+ 'resolution': [16, 8, 4, 4],
77
+ 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
78
+ for i in range(2, 5)}}
79
+
80
+ arch[17] = {'in_channels': [input_nc] + [ch * item for item in [1, 4]],
81
+ 'out_channels': [item * ch for item in [1, 4, 8]],
82
+ 'downsample': [True] * 3,
83
+ 'resolution': [16, 8, 4],
84
+ 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
85
+ for i in range(2, 5)}}
86
+
87
+
88
+ arch[20] = {'in_channels': [input_nc] + [ch * item for item in [1, 8, 16]],
89
+ 'out_channels': [item * ch for item in [1, 8, 16, 16]],
90
+ 'downsample': [True] * 3 + [False],
91
+ 'resolution': [16, 8, 4, 4],
92
+ 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
93
+ for i in range(2, 5)}}
94
+ return arch
95
+
96
+
97
+ class Discriminator(nn.Module):
98
+
99
+ def __init__(self, resolution, D_ch=64, D_wide=True, D_kernel_size=3, D_attn='64',
100
+ num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
101
+ SN_eps=1e-8, output_dim=1, D_mixed_precision=False, D_fp16=False,
102
+ D_init='N02', skip_init=False, D_param='SN', gpu_ids=[0],bn_linear='SN', input_nc=1, one_hot=False, crop_size: list = None, **kwargs):
103
+
104
+ super(Discriminator, self).__init__()
105
+ self.crop = crop_size is not None and len(crop_size) > 0
106
+
107
+ use_padding = False
108
+
109
+ if self.crop:
110
+ w_crop = StaticWordCrop(crop_size[0], use_padding=use_padding) if len(crop_size) == 1 else RandomWordCrop(crop_size[0], crop_size[1], use_padding=use_padding)
111
+
112
+ self.augmenter = w_crop
113
+
114
+ self.name = 'D'
115
+ # gpu_ids
116
+ self.gpu_ids = gpu_ids
117
+ # one_hot representation
118
+ self.one_hot = one_hot
119
+ # Width multiplier
120
+ self.ch = D_ch
121
+ # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
122
+ self.D_wide = D_wide
123
+ # Resolution
124
+ self.resolution = resolution
125
+ # Kernel size
126
+ self.kernel_size = D_kernel_size
127
+ # Attention?
128
+ self.attention = D_attn
129
+ # Activation
130
+ self.activation = D_activation
131
+ # Initialization style
132
+ self.init = D_init
133
+ # Parameterization style
134
+ self.D_param = D_param
135
+ # Epsilon for Spectral Norm?
136
+ self.SN_eps = SN_eps
137
+ # Fp16?
138
+ self.fp16 = D_fp16
139
+ # Architecture
140
+ self.arch = D_arch(self.ch, self.attention, input_nc)[resolution]
141
+
142
+ # Which convs, batchnorms, and linear layers to use
143
+ # No option to turn off SN in D right now
144
+ if self.D_param == 'SN':
145
+ self.which_conv = functools.partial(layers.SNConv2d,
146
+ kernel_size=3, padding=1,
147
+ num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
148
+ eps=self.SN_eps)
149
+ self.which_linear = functools.partial(layers.SNLinear,
150
+ num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
151
+ eps=self.SN_eps)
152
+ self.which_embedding = functools.partial(layers.SNEmbedding,
153
+ num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
154
+ eps=self.SN_eps)
155
+ if bn_linear=='SN':
156
+ self.which_embedding = functools.partial(layers.SNLinear,
157
+ num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
158
+ eps=self.SN_eps)
159
+ else:
160
+ self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
161
+ self.which_linear = nn.Linear
162
+ # We use a non-spectral-normed embedding here regardless;
163
+ # For some reason applying SN to G's embedding seems to randomly cripple G
164
+ self.which_embedding = nn.Embedding
165
+ if one_hot:
166
+ self.which_embedding = functools.partial(layers.SNLinear,
167
+ num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
168
+ eps=self.SN_eps)
169
+ # Prepare model
170
+ # self.blocks is a doubly-nested list of modules, the outer loop intended
171
+ # to be over blocks at a given resolution (resblocks and/or self-attention)
172
+ self.blocks = []
173
+ for index in range(len(self.arch['out_channels'])):
174
+ self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index],
175
+ out_channels=self.arch['out_channels'][index],
176
+ which_conv=self.which_conv,
177
+ wide=self.D_wide,
178
+ activation=self.activation,
179
+ preactivation=(index > 0),
180
+ downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]]
181
+ # If attention on this block, attach it to the end
182
+ if self.arch['attention'][self.arch['resolution'][index]]:
183
+ print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
184
+ self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
185
+ self.which_conv)]
186
+ # Turn self.blocks into a ModuleList so that it's all properly registered.
187
+ self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
188
+ # Linear output layer. The output dimension is typically 1, but may be
189
+ # larger if we're e.g. turning this into a VAE with an inference output
190
+ self.dropout = torch.nn.Dropout(p=0.5)
191
+ self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
192
+
193
+ # Initialize weights
194
+ if not skip_init:
195
+ self = init_weights(self, D_init)
196
+
197
+ def update_parameters(self, epoch: int):
198
+ if self.crop:
199
+ self.augmenter.update(epoch)
200
+
201
+ def forward(self, x, y=None, **kwargs):
202
+ # Stick x into h for cleaner for loops without flow control
203
+ if self.crop and random.uniform(0.0, 1.0) < 0.33:
204
+ x = self.augmenter(x)
205
+
206
+ #imgs = [np.squeeze((img.detach().cpu().numpy() + 1.0) / 2.0) for img in x]
207
+ #imgs = (np.vstack(imgs) * 255.0).astype(np.uint8)
208
+ #cv2.imwrite(f"saved_images/debug/{random.randint(0, 1000)}.jpg", imgs)
209
+
210
+ h = x
211
+ # Loop over blocks
212
+ for index, blocklist in enumerate(self.blocks):
213
+ for block in blocklist:
214
+ h = block(h)
215
+
216
+ # Apply global sum pooling as in SN-GAN
217
+ h = torch.sum(self.activation(h), [2, 3])
218
+ out = self.linear(h)
219
+
220
+ return out
221
+
222
+ def return_features(self, x, y=None):
223
+ # Stick x into h for cleaner for loops without flow control
224
+ h = x
225
+ block_output = []
226
+ # Loop over blocks
227
+ for index, blocklist in enumerate(self.blocks):
228
+ for block in blocklist:
229
+ h = block(h)
230
+ block_output.append(h)
231
+ # Apply global sum pooling as in SN-GAN
232
+ # h = torch.sum(self.activation(h), [2, 3])
233
+ return block_output
234
+
235
+
236
+ class WDiscriminator(nn.Module):
237
+
238
+ def __init__(self, resolution, n_classes, output_dim, D_ch=64, D_wide=True, D_kernel_size=3, D_attn='64',
239
+ num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
240
+ SN_eps=1e-8, D_mixed_precision=False, D_fp16=False,
241
+ D_init='N02', skip_init=False, D_param='SN', gpu_ids=[0],bn_linear='SN', input_nc=1, one_hot=False):
242
+ super(WDiscriminator, self).__init__()
243
+
244
+ self.name = 'D'
245
+ # gpu_ids
246
+ self.gpu_ids = gpu_ids
247
+ # one_hot representation
248
+ self.one_hot = one_hot
249
+ # Width multiplier
250
+ self.ch = D_ch
251
+ # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
252
+ self.D_wide = D_wide
253
+ # Resolution
254
+ self.resolution = resolution
255
+ # Kernel size
256
+ self.kernel_size = D_kernel_size
257
+ # Attention?
258
+ self.attention = D_attn
259
+ # Number of classes
260
+ self.n_classes = n_classes
261
+ # Activation
262
+ self.activation = D_activation
263
+ # Initialization style
264
+ self.init = D_init
265
+ # Parameterization style
266
+ self.D_param = D_param
267
+ # Epsilon for Spectral Norm?
268
+ self.SN_eps = SN_eps
269
+ # Fp16?
270
+ self.fp16 = D_fp16
271
+ # Architecture
272
+ self.arch = D_arch(self.ch, self.attention, input_nc)[resolution]
273
+
274
+ # Which convs, batchnorms, and linear layers to use
275
+ # No option to turn off SN in D right now
276
+ if self.D_param == 'SN':
277
+ self.which_conv = functools.partial(layers.SNConv2d,
278
+ kernel_size=3, padding=1,
279
+ num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
280
+ eps=self.SN_eps)
281
+ self.which_linear = functools.partial(layers.SNLinear,
282
+ num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
283
+ eps=self.SN_eps)
284
+ self.which_embedding = functools.partial(layers.SNEmbedding,
285
+ num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
286
+ eps=self.SN_eps)
287
+ if bn_linear == 'SN':
288
+ self.which_embedding = functools.partial(layers.SNLinear,
289
+ num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
290
+ eps=self.SN_eps)
291
+ else:
292
+ self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
293
+ self.which_linear = nn.Linear
294
+ # We use a non-spectral-normed embedding here regardless;
295
+ # For some reason applying SN to G's embedding seems to randomly cripple G
296
+ self.which_embedding = nn.Embedding
297
+ if one_hot:
298
+ self.which_embedding = functools.partial(layers.SNLinear,
299
+ num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
300
+ eps=self.SN_eps)
301
+ # Prepare model
302
+ # self.blocks is a doubly-nested list of modules, the outer loop intended
303
+ # to be over blocks at a given resolution (resblocks and/or self-attention)
304
+ self.blocks = []
305
+ for index in range(len(self.arch['out_channels'])):
306
+ self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index],
307
+ out_channels=self.arch['out_channels'][index],
308
+ which_conv=self.which_conv,
309
+ wide=self.D_wide,
310
+ activation=self.activation,
311
+ preactivation=(index > 0),
312
+ downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]]
313
+ # If attention on this block, attach it to the end
314
+ if self.arch['attention'][self.arch['resolution'][index]]:
315
+ print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
316
+ self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
317
+ self.which_conv)]
318
+ # Turn self.blocks into a ModuleList so that it's all properly registered.
319
+ self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
320
+ # Linear output layer. The output dimension is typically 1, but may be
321
+ # larger if we're e.g. turning this into a VAE with an inference output
322
+ self.dropout = torch.nn.Dropout(p=0.5)
323
+ self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
324
+ # Embedding for projection discrimination
325
+ self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])
326
+ self.cross_entropy = nn.CrossEntropyLoss()
327
+ # Initialize weights
328
+ if not skip_init:
329
+ self = init_weights(self, D_init)
330
+
331
+ def update_parameters(self, epoch: int):
332
+ pass
333
+
334
+ def forward(self, x, y=None, **kwargs):
335
+ # Stick x into h for cleaner for loops without flow control
336
+ h = x
337
+ # Loop over blocks
338
+ for index, blocklist in enumerate(self.blocks):
339
+ for block in blocklist:
340
+ h = block(h)
341
+ # Apply global sum pooling as in SN-GAN
342
+ h = torch.sum(self.activation(h), [2, 3])
343
+
344
+ # Get initial class-unconditional output
345
+ out = self.linear(h)
346
+ # Writer classification: cross-entropy between the linear output and the writer label y
347
+ #if y is not None:
348
+ loss = self.cross_entropy(out, y.long())
349
+ return loss
350
+
351
+ def return_features(self, x, y=None):
352
+ # Stick x into h for cleaner for loops without flow control
353
+ h = x
354
+ block_output = []
355
+ # Loop over blocks
356
+ for index, blocklist in enumerate(self.blocks):
357
+ for block in blocklist:
358
+ h = block(h)
359
+ block_output.append(h)
360
+ # Apply global sum pooling as in SN-GAN
361
+ # h = torch.sum(self.activation(h), [2, 3])
362
+ return block_output
363
+
364
+
365
+ class Encoder(Discriminator):
366
+ def __init__(self, opt, output_dim, **kwargs):
367
+ super(Encoder, self).__init__(**vars(opt))
368
+ self.output_layer = nn.Sequential(self.activation,
369
+ nn.Conv2d(self.arch['out_channels'][-1], output_dim, kernel_size=(4,2), padding=0, stride=2))
370
+
371
+ def forward(self, x):
372
+ # Stick x into h for cleaner for loops without flow control
373
+ h = x
374
+ # Loop over blocks
375
+ for index, blocklist in enumerate(self.blocks):
376
+ for block in blocklist:
377
+ h = block(h)
378
+ out = self.output_layer(h)
379
+ return out
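
A hedged usage sketch for the architecture table above: `D_arch` returns, per input resolution, the channel and downsampling schedule that the `Discriminator` constructor iterates over. This snippet is not part of the commit and assumes the repository root is on `PYTHONPATH` with its dependencies (torch, torchvision, cv2, the `util` package) installed.

```python
from models.BigGAN_networks import D_arch

arch = D_arch(ch=64, attention='64', input_nc=1)[32]  # the 32-pixel variant
print(arch['in_channels'])   # [1, 256, 256, 256]
print(arch['out_channels'])  # [256, 256, 256, 256]
print(arch['downsample'])    # [True, True, False, False]
print(arch['resolution'])    # [16, 16, 16, 16]
```
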
models/OCR_network.py ADDED
@@ -0,0 +1,193 @@
1
+ import torch
2
+ from .networks import *
3
+
4
+
5
+ class BidirectionalLSTM(nn.Module):
6
+
7
+ def __init__(self, nIn, nHidden, nOut):
8
+ super(BidirectionalLSTM, self).__init__()
9
+
10
+ self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
11
+ self.embedding = nn.Linear(nHidden * 2, nOut)
12
+
13
+
14
+ def forward(self, input):
15
+ recurrent, _ = self.rnn(input)
16
+ T, b, h = recurrent.size()
17
+ t_rec = recurrent.view(T * b, h)
18
+
19
+ output = self.embedding(t_rec) # [T * b, nOut]
20
+ output = output.view(T, b, -1)
21
+
22
+ return output
23
+
24
+
25
+ class CRNN(nn.Module):
26
+
27
+ def __init__(self, args, leakyRelu=False):
28
+ super(CRNN, self).__init__()
29
+ self.args = args
30
+ self.name = 'OCR'
31
+ self.add_noise = False
32
+ self.noise_fac = torch.distributions.Normal(loc=torch.tensor([0.]), scale=torch.tensor([0.2]))
33
+ #assert opt.imgH % 16 == 0, 'imgH has to be a multiple of 16'
34
+
35
+ ks = [3, 3, 3, 3, 3, 3, 2]
36
+ ps = [1, 1, 1, 1, 1, 1, 0]
37
+ ss = [1, 1, 1, 1, 1, 1, 1]
38
+ nm = [64, 128, 256, 256, 512, 512, 512]
39
+
40
+ cnn = nn.Sequential()
41
+ nh = 256
42
+ dealwith_lossnone = False  # whether to replace all nan/inf in gradients to zero
43
+
44
+ def convRelu(i, batchNormalization=False):
45
+ nIn = 1 if i == 0 else nm[i - 1]
46
+ nOut = nm[i]
47
+ cnn.add_module('conv{0}'.format(i),
48
+ nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
49
+ if batchNormalization:
50
+ cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
51
+ if leakyRelu:
52
+ cnn.add_module('relu{0}'.format(i),
53
+ nn.LeakyReLU(0.2, inplace=True))
54
+ else:
55
+ cnn.add_module('relu{0}'.format(i), nn.ReLU(True))
56
+
57
+ convRelu(0)
58
+ cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2)) # 64x16x64
59
+ convRelu(1)
60
+ cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2)) # 128x8x32
61
+ convRelu(2, True)
62
+ convRelu(3)
63
+ cnn.add_module('pooling{0}'.format(2),
64
+ nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16
65
+ convRelu(4, True)
66
+ if self.args.resolution==63:
67
+ cnn.add_module('pooling{0}'.format(3),
68
+ nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16
69
+ convRelu(5)
70
+ cnn.add_module('pooling{0}'.format(4),
71
+ nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16
72
+ convRelu(6, True) # 512x1x16
73
+
74
+ self.cnn = cnn
75
+ self.use_rnn = False
76
+ if self.use_rnn:
77
+ self.rnn = nn.Sequential(
78
+ BidirectionalLSTM(512, nh, nh),
79
+ BidirectionalLSTM(nh, nh, self.args.vocab_size))
80
+ else:
81
+ self.linear = nn.Linear(512, self.args.vocab_size)
82
+
83
+ # replace all nan/inf in gradients to zero
84
+ if dealwith_lossnone:
85
+ self.register_backward_hook(self.backward_hook)
86
+
87
+ self.device = torch.device('cuda:{}'.format(0))
88
+ self.init = 'N02'
89
+ # Initialize weights
90
+
91
+ self = init_weights(self, self.init)
92
+
93
+ def forward(self, input):
94
+ # conv features
95
+ if self.add_noise:
96
+ input = input + self.noise_fac.sample(input.size()).squeeze(-1).to(self.args.device)
97
+ conv = self.cnn(input)
98
+ b, c, h, w = conv.size()
99
+ assert h == 1, "the height of conv must be 1"
102
+ conv = conv.squeeze(2)
103
+ conv = conv.permute(2, 0, 1) # [w, b, c]
104
+
105
+ if self.use_rnn:
106
+ # rnn features
107
+ output = self.rnn(conv)
108
+ else:
109
+ output = self.linear(conv)
110
+ return output
111
+
112
+ def backward_hook(self, module, grad_input, grad_output):
113
+ for g in grad_input:
114
+ g[g != g] = 0 # replace all nan/inf in gradients to zero
115
+
116
+
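
A hedged shape sketch for the `CRNN` recognizer above (not part of the commit; it assumes the repository is importable and uses a hypothetical `SimpleNamespace` in place of the real argument parser). With 32-pixel-high inputs the CNN collapses the height to 1 and shrinks the width by roughly a factor of 4, producing a `(T, B, vocab_size)` sequence for the CTC loss.

```python
import torch
from types import SimpleNamespace
from models.OCR_network import CRNN

# hypothetical settings standing in for the repository's argument parser
args = SimpleNamespace(resolution=32, vocab_size=80, device='cpu')
ocr = CRNN(args)

x = torch.randn(2, 1, 32, 128)   # (B, 1, H, W) grayscale word images
print(ocr(x).shape)              # torch.Size([33, 2, 80]) -> (T, B, vocab_size)
```
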
117
+ class strLabelConverter(object):
118
+ """Convert between str and label.
119
+ NOTE:
120
+ Insert `blank` to the alphabet for CTC.
121
+ Args:
122
+ alphabet (str): set of the possible characters.
123
+ ignore_case (bool, default=False): whether to ignore character case.
124
+ """
125
+
126
+ def __init__(self, alphabet, ignore_case=False):
127
+ self._ignore_case = ignore_case
128
+ if self._ignore_case:
129
+ alphabet = alphabet.lower()
130
+ self.alphabet = alphabet + '-' # for `-1` index
131
+
132
+ self.dict = {}
133
+ for i, char in enumerate(alphabet):
134
+ # NOTE: 0 is reserved for 'blank' required by wrap_ctc
135
+ self.dict[char] = i + 1
136
+
137
+ def encode(self, text):
138
+ """Support batch or single str.
139
+ Args:
140
+ text (str or list of str): texts to convert.
141
+ Returns:
142
+ torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
143
+ torch.IntTensor [n]: length of each text.
144
+ """
145
+ length = []
146
+ result = []
147
+ results = []
148
+ for item in text:
149
+ if isinstance(item, bytes): item = item.decode('utf-8', 'strict')
150
+ length.append(len(item))
151
+ for char in item:
152
+ index = self.dict[char]
153
+ result.append(index)
154
+ results.append(result)
155
+ result = []
156
+
157
+ return torch.nn.utils.rnn.pad_sequence([torch.LongTensor(text) for text in results], batch_first=True), torch.IntTensor(length), None
158
+
159
+ def decode(self, t, length, raw=False):
160
+ """Decode encoded texts back into strs.
161
+ Args:
162
+ torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
163
+ torch.IntTensor [n]: length of each text.
164
+ Raises:
165
+ AssertionError: when the text and its declared length do not match.
166
+ Returns:
167
+ text (str or list of str): decoded texts.
168
+ """
169
+ if length.numel() == 1:
170
+ length = length[0]
171
+ assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(),
172
+ length)
173
+ if raw:
174
+ return ''.join([self.alphabet[i - 1] for i in t])
175
+ else:
176
+ char_list = []
177
+ for i in range(length):
178
+ if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
179
+ char_list.append(self.alphabet[t[i] - 1])
180
+ return ''.join(char_list)
181
+ else:
182
+ # batch mode
183
+ assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(
184
+ t.numel(), length.sum())
185
+ texts = []
186
+ index = 0
187
+ for i in range(length.numel()):
188
+ l = length[i]
189
+ texts.append(
190
+ self.decode(
191
+ t[index:index + l], torch.IntTensor([l]), raw=raw))
192
+ index += l
193
+ return texts
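
A hedged round-trip example for `strLabelConverter` above (not part of the commit; assumes the repository is importable). `encode` pads a batch of words into CTC label tensors, while `decode` with `raw=False` collapses repeated indices and the blank index 0, as in greedy CTC decoding.

```python
import torch
from models.OCR_network import strLabelConverter

converter = strLabelConverter('abcdefghijklmnopqrstuvwxyz')
labels, lengths, _ = converter.encode(['hello', 'hi'])
print(labels.shape)  # torch.Size([2, 5]) -- padded to the longest word
print(lengths)       # tensor([5, 2], dtype=torch.int32)

t = labels[0][:lengths[0]]
print(converter.decode(t, torch.IntTensor([int(lengths[0])]), raw=True))   # 'hello'
print(converter.decode(t, torch.IntTensor([int(lengths[0])]), raw=False))  # 'helo' (repeats collapsed)
```
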
models/__init__.py ADDED
@@ -0,0 +1,65 @@
1
+ """This package contains modules related to objective functions, optimizations, and network architectures.
2
+
3
+ To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4
+ You need to implement the following five functions:
5
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6
+ -- <set_input>: unpack data from dataset and apply preprocessing.
7
+ -- <forward>: produce intermediate results.
8
+ -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10
+
11
+ In the function <__init__>, you need to define four lists:
12
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
13
+ -- self.model_names (str list): define networks used in our training.
14
+ -- self.visual_names (str list): specify the images that you want to display and save.
15
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example of usage.
16
+
17
+ Now you can use the model class by specifying flag '--model dummy'.
18
+ """
19
+
20
+ import importlib
21
+
22
+
23
+ def find_model_using_name(model_name):
24
+ """Import the module "models/[model_name]_model.py".
25
+
26
+ In the file, the class called DatasetNameModel() will
27
+ be instantiated. It has to be a subclass of BaseModel,
28
+ and it is case-insensitive.
29
+ """
30
+ model_filename = "models." + model_name + "_model"
31
+ modellib = importlib.import_module(model_filename)
32
+ model = None
33
+ target_model_name = model_name.replace('_', '') + 'model'
34
+ for name, cls in modellib.__dict__.items():
35
+ if name.lower() == target_model_name.lower() \
36
+ and issubclass(cls, BaseModel):
37
+ model = cls
38
+
39
+ if model is None:
40
+ print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
41
+ exit(0)
42
+
43
+ return model
44
+
45
+
46
+ def get_option_setter(model_name):
47
+ """Return the static method <modify_commandline_options> of the model class."""
48
+ model_class = find_model_using_name(model_name)
49
+ return model_class.modify_commandline_options
50
+
51
+
52
+ def create_model(opt):
53
+ """Create a model given the option.
54
+
55
+ This function wraps the model class.
56
+ This is the main interface between this package and 'train.py'/'test.py'
57
+
58
+ Example:
59
+ >>> from models import create_model
60
+ >>> model = create_model(opt)
61
+ """
62
+ model = find_model_using_name(opt.model)
63
+ instance = model(opt)
64
+ print("model [%s] was created" % type(instance).__name__)
65
+ return instance
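
A small illustration (not part of the commit, and deliberately not calling into the repository) of the naming convention documented above: the option `--model dummy` is resolved to the module `models/dummy_model.py` and the class `DummyModel`.

```python
model_name = 'dummy'
model_filename = "models." + model_name + "_model"          # -> 'models.dummy_model'
target_model_name = model_name.replace('_', '') + 'model'   # -> 'dummymodel' (matched case-insensitively)
print(model_filename, target_model_name)
```
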
models/blocks.py ADDED
@@ -0,0 +1,190 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+
6
+ class ResBlocks(nn.Module):
7
+ def __init__(self, num_blocks, dim, norm, activation, pad_type):
8
+ super(ResBlocks, self).__init__()
9
+ self.model = []
10
+ for i in range(num_blocks):
11
+ self.model += [ResBlock(dim,
12
+ norm=norm,
13
+ activation=activation,
14
+ pad_type=pad_type)]
15
+ self.model = nn.Sequential(*self.model)
16
+
17
+ def forward(self, x):
18
+ return self.model(x)
19
+
20
+
21
+ class ResBlock(nn.Module):
22
+ def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
23
+ super(ResBlock, self).__init__()
24
+ model = []
25
+ model += [Conv2dBlock(dim, dim, 3, 1, 1,
26
+ norm=norm,
27
+ activation=activation,
28
+ pad_type=pad_type)]
29
+ model += [Conv2dBlock(dim, dim, 3, 1, 1,
30
+ norm=norm,
31
+ activation='none',
32
+ pad_type=pad_type)]
33
+ self.model = nn.Sequential(*model)
34
+
35
+ def forward(self, x):
36
+ residual = x
37
+ out = self.model(x)
38
+ out += residual
39
+ return out
40
+
41
+
42
+ class ActFirstResBlock(nn.Module):
43
+ def __init__(self, fin, fout, fhid=None,
44
+ activation='lrelu', norm='none'):
45
+ super().__init__()
46
+ self.learned_shortcut = (fin != fout)
47
+ self.fin = fin
48
+ self.fout = fout
49
+ self.fhid = min(fin, fout) if fhid is None else fhid
50
+ self.conv_0 = Conv2dBlock(self.fin, self.fhid, 3, 1,
51
+ padding=1, pad_type='reflect', norm=norm,
52
+ activation=activation, activation_first=True)
53
+ self.conv_1 = Conv2dBlock(self.fhid, self.fout, 3, 1,
54
+ padding=1, pad_type='reflect', norm=norm,
55
+ activation=activation, activation_first=True)
56
+ if self.learned_shortcut:
57
+ self.conv_s = Conv2dBlock(self.fin, self.fout, 1, 1,
58
+ activation='none', use_bias=False)
59
+
60
+ def forward(self, x):
61
+ x_s = self.conv_s(x) if self.learned_shortcut else x
62
+ dx = self.conv_0(x)
63
+ dx = self.conv_1(dx)
64
+ out = x_s + dx
65
+ return out
66
+
67
+
68
+ class LinearBlock(nn.Module):
69
+ def __init__(self, in_dim, out_dim, norm='none', activation='relu'):
70
+ super(LinearBlock, self).__init__()
71
+ use_bias = True
72
+ self.fc = nn.Linear(in_dim, out_dim, bias=use_bias)
73
+
74
+ # initialize normalization
75
+ norm_dim = out_dim
76
+ if norm == 'bn':
77
+ self.norm = nn.BatchNorm1d(norm_dim)
78
+ elif norm == 'in':
79
+ self.norm = nn.InstanceNorm1d(norm_dim)
80
+ elif norm == 'none':
81
+ self.norm = None
82
+ else:
83
+ assert 0, "Unsupported normalization: {}".format(norm)
84
+
85
+ # initialize activation
86
+ if activation == 'relu':
87
+ self.activation = nn.ReLU(inplace=False)
88
+ elif activation == 'lrelu':
89
+ self.activation = nn.LeakyReLU(0.2, inplace=False)
90
+ elif activation == 'tanh':
91
+ self.activation = nn.Tanh()
92
+ elif activation == 'none':
93
+ self.activation = None
94
+ else:
95
+ assert 0, "Unsupported activation: {}".format(activation)
96
+
97
+ def forward(self, x):
98
+ out = self.fc(x)
99
+ if self.norm:
100
+ out = self.norm(out)
101
+ if self.activation:
102
+ out = self.activation(out)
103
+ return out
104
+
105
+
106
+ class Conv2dBlock(nn.Module):
107
+ def __init__(self, in_dim, out_dim, ks, st, padding=0,
108
+ norm='none', activation='relu', pad_type='zero',
109
+ use_bias=True, activation_first=False):
110
+ super(Conv2dBlock, self).__init__()
111
+ self.use_bias = use_bias
112
+ self.activation_first = activation_first
113
+ # initialize padding
114
+ if pad_type == 'reflect':
115
+ self.pad = nn.ReflectionPad2d(padding)
116
+ elif pad_type == 'replicate':
117
+ self.pad = nn.ReplicationPad2d(padding)
118
+ elif pad_type == 'zero':
119
+ self.pad = nn.ZeroPad2d(padding)
120
+ else:
121
+ assert 0, "Unsupported padding type: {}".format(pad_type)
122
+
123
+ # initialize normalization
124
+ norm_dim = out_dim
125
+ if norm == 'bn':
126
+ self.norm = nn.BatchNorm2d(norm_dim)
127
+ elif norm == 'in':
128
+ self.norm = nn.InstanceNorm2d(norm_dim)
129
+ elif norm == 'adain':
130
+ self.norm = AdaptiveInstanceNorm2d(norm_dim)
131
+ elif norm == 'none':
132
+ self.norm = None
133
+ else:
134
+ assert 0, "Unsupported normalization: {}".format(norm)
135
+
136
+ # initialize activation
137
+ if activation == 'relu':
138
+ self.activation = nn.ReLU(inplace=False)
139
+ elif activation == 'lrelu':
140
+ self.activation = nn.LeakyReLU(0.2, inplace=False)
141
+ elif activation == 'tanh':
142
+ self.activation = nn.Tanh()
143
+ elif activation == 'none':
144
+ self.activation = None
145
+ else:
146
+ assert 0, "Unsupported activation: {}".format(activation)
147
+
148
+ self.conv = nn.Conv2d(in_dim, out_dim, ks, st, bias=self.use_bias)
149
+
150
+ def forward(self, x):
151
+ if self.activation_first:
152
+ if self.activation:
153
+ x = self.activation(x)
154
+ x = self.conv(self.pad(x))
155
+ if self.norm:
156
+ x = self.norm(x)
157
+ else:
158
+ x = self.conv(self.pad(x))
159
+ if self.norm:
160
+ x = self.norm(x)
161
+ if self.activation:
162
+ x = self.activation(x)
163
+ return x
164
+
165
+
166
+ class AdaptiveInstanceNorm2d(nn.Module):
167
+ def __init__(self, num_features, eps=1e-5, momentum=0.1):
168
+ super(AdaptiveInstanceNorm2d, self).__init__()
169
+ self.num_features = num_features
170
+ self.eps = eps
171
+ self.momentum = momentum
172
+ self.weight = None
173
+ self.bias = None
174
+ self.register_buffer('running_mean', torch.zeros(num_features))
175
+ self.register_buffer('running_var', torch.ones(num_features))
176
+
177
+ def forward(self, x):
178
+ assert self.weight is not None and \
179
+ self.bias is not None, "Please assign AdaIN weight first"
180
+ b, c = x.size(0), x.size(1)
181
+ running_mean = self.running_mean.repeat(b)
182
+ running_var = self.running_var.repeat(b)
183
+ x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
184
+ out = F.batch_norm(
185
+ x_reshaped, running_mean, running_var, self.weight, self.bias,
186
+ True, self.momentum, self.eps)
187
+ return out.view(b, c, *x.size()[2:])
188
+
189
+ def __repr__(self):
190
+ return self.__class__.__name__ + '(' + str(self.num_features) + ')'
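
A hedged shape check for the building blocks above (not part of the commit; assumes the repository is importable). `ResBlocks` and a stride-1 `Conv2dBlock` with matching padding preserve the spatial size, which is what the generator's decoder relies on.

```python
import torch
from models.blocks import Conv2dBlock, ResBlocks

x = torch.randn(2, 512, 4, 32)   # (B, C, H, W) feature map
res = ResBlocks(num_blocks=2, dim=512, norm='in', activation='relu', pad_type='reflect')
conv = Conv2dBlock(512, 256, 5, 1, 2, norm='in', activation='relu', pad_type='reflect')
print(res(x).shape)   # torch.Size([2, 512, 4, 32])
print(conv(x).shape)  # torch.Size([2, 256, 4, 32])
```
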
models/config.py ADDED
@@ -0,0 +1,6 @@
1
+ tn_hidden_dim = 512
2
+ tn_dropout = 0.1
3
+ tn_nheads = 8
4
+ tn_dim_feedforward = 512
5
+ tn_enc_layers = 3
6
+ tn_dec_layers = 3
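
These constants size the style encoder and content decoder built in `models/model.py`. As a hedged illustration only (the repository defines its own transformer modules in `models/transformer.py`), the same hyperparameters map onto `torch.nn` equivalents like this:

```python
import torch.nn as nn
import models.config as config

layer = nn.TransformerEncoderLayer(
    d_model=config.tn_hidden_dim,               # 512-dimensional tokens
    nhead=config.tn_nheads,                     # 8 attention heads
    dim_feedforward=config.tn_dim_feedforward,  # 512-dimensional feed-forward block
    dropout=config.tn_dropout)                  # 0.1
encoder = nn.TransformerEncoder(layer, num_layers=config.tn_enc_layers)  # 3 layers
```
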
models/inception.py ADDED
@@ -0,0 +1,311 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision import models
5
+
6
+ try:
7
+ from torchvision.models.utils import load_state_dict_from_url
8
+ except ImportError:
9
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
10
+
11
+ # Inception weights ported to Pytorch from
12
+ # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13
+ FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
14
+
15
+
16
+ class InceptionV3(nn.Module):
17
+ """Pretrained InceptionV3 network returning feature maps"""
18
+
19
+ # Index of default block of inception to return,
20
+ # corresponds to output of final average pooling
21
+ DEFAULT_BLOCK_INDEX = 3
22
+
23
+ # Maps feature dimensionality to their output blocks indices
24
+ BLOCK_INDEX_BY_DIM = {
25
+ 64: 0, # First max pooling features
26
+ 192: 1, # Second max pooling featurs
27
+ 768: 2, # Pre-aux classifier features
28
+ 2048: 3 # Final average pooling features
29
+ }
30
+
31
+ def __init__(self,
32
+ output_blocks=[DEFAULT_BLOCK_INDEX],
33
+ resize_input=True,
34
+ normalize_input=True,
35
+ requires_grad=False,
36
+ use_fid_inception=True):
37
+ """Build pretrained InceptionV3
38
+
39
+ Parameters
40
+ ----------
41
+ output_blocks : list of int
42
+ Indices of blocks to return features of. Possible values are:
43
+ - 0: corresponds to output of first max pooling
44
+ - 1: corresponds to output of second max pooling
45
+ - 2: corresponds to output which is fed to aux classifier
46
+ - 3: corresponds to output of final average pooling
47
+ resize_input : bool
48
+ If true, bilinearly resizes input to width and height 299 before
49
+ feeding input to model. As the network without fully connected
50
+ layers is fully convolutional, it should be able to handle inputs
51
+ of arbitrary size, so resizing might not be strictly needed
52
+ normalize_input : bool
53
+ If true, scales the input from range (0, 1) to the range the
54
+ pretrained Inception network expects, namely (-1, 1)
55
+ requires_grad : bool
56
+ If true, parameters of the model require gradients. Possibly useful
57
+ for finetuning the network
58
+ use_fid_inception : bool
59
+ If true, uses the pretrained Inception model used in Tensorflow's
60
+ FID implementation. If false, uses the pretrained Inception model
61
+ available in torchvision. The FID Inception model has different
62
+ weights and a slightly different structure from torchvision's
63
+ Inception model. If you want to compute FID scores, you are
64
+ strongly advised to set this parameter to true to get comparable
65
+ results.
66
+ """
67
+ super(InceptionV3, self).__init__()
68
+
69
+ self.resize_input = resize_input
70
+ self.normalize_input = normalize_input
71
+ self.output_blocks = sorted(output_blocks)
72
+ self.last_needed_block = max(output_blocks)
73
+
74
+ assert self.last_needed_block <= 3, \
75
+ 'Last possible output block index is 3'
76
+
77
+ self.blocks = nn.ModuleList()
78
+
79
+ if use_fid_inception:
80
+ inception = fid_inception_v3()
81
+ else:
82
+ inception = models.inception_v3(pretrained=True)
83
+
84
+ # Block 0: input to maxpool1
85
+ block0 = [
86
+ inception.Conv2d_1a_3x3,
87
+ inception.Conv2d_2a_3x3,
88
+ inception.Conv2d_2b_3x3,
89
+ nn.MaxPool2d(kernel_size=3, stride=2)
90
+ ]
91
+ self.blocks.append(nn.Sequential(*block0))
92
+
93
+ # Block 1: maxpool1 to maxpool2
94
+ if self.last_needed_block >= 1:
95
+ block1 = [
96
+ inception.Conv2d_3b_1x1,
97
+ inception.Conv2d_4a_3x3,
98
+ nn.MaxPool2d(kernel_size=3, stride=2)
99
+ ]
100
+ self.blocks.append(nn.Sequential(*block1))
101
+
102
+ # Block 2: maxpool2 to aux classifier
103
+ if self.last_needed_block >= 2:
104
+ block2 = [
105
+ inception.Mixed_5b,
106
+ inception.Mixed_5c,
107
+ inception.Mixed_5d,
108
+ inception.Mixed_6a,
109
+ inception.Mixed_6b,
110
+ inception.Mixed_6c,
111
+ inception.Mixed_6d,
112
+ inception.Mixed_6e,
113
+ ]
114
+ self.blocks.append(nn.Sequential(*block2))
115
+
116
+ # Block 3: aux classifier to final avgpool
117
+ if self.last_needed_block >= 3:
118
+ block3 = [
119
+ inception.Mixed_7a,
120
+ inception.Mixed_7b,
121
+ inception.Mixed_7c,
122
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
123
+ ]
124
+ self.blocks.append(nn.Sequential(*block3))
125
+
126
+ for param in self.parameters():
127
+ param.requires_grad = requires_grad
128
+
129
+ def forward(self, inp):
130
+ """Get Inception feature maps
131
+
132
+ Parameters
133
+ ----------
134
+ inp : torch.autograd.Variable
135
+ Input tensor of shape Bx3xHxW. Values are expected to be in
136
+ range (0, 1)
137
+
138
+ Returns
139
+ -------
140
+ List of torch.autograd.Variable, corresponding to the selected output
141
+ block, sorted ascending by index
142
+ """
143
+ outp = []
144
+ x = inp
145
+
146
+ if self.resize_input:
147
+ x = F.interpolate(x,
148
+ size=(299, 299),
149
+ mode='bilinear',
150
+ align_corners=False)
151
+
152
+ if self.normalize_input:
153
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
154
+
155
+ for idx, block in enumerate(self.blocks):
156
+ x = block(x)
157
+ if idx in self.output_blocks:
158
+ outp.append(x)
159
+
160
+ if idx == self.last_needed_block:
161
+ break
162
+
163
+ return outp
164
+
165
+
166
+ def fid_inception_v3():
167
+ """Build pretrained Inception model for FID computation
168
+
169
+ The Inception model for FID computation uses a different set of weights
170
+ and has a slightly different structure than torchvision's Inception.
171
+
172
+ This method first constructs torchvision's Inception and then patches the
173
+ necessary parts that are different in the FID Inception model.
174
+ """
175
+ inception = models.inception_v3(num_classes=1008,
176
+ aux_logits=False,
177
+ weights=None,
178
+ init_weights=False)
179
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
180
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
181
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
182
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
183
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
184
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
185
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
186
+ inception.Mixed_7b = FIDInceptionE_1(1280)
187
+ inception.Mixed_7c = FIDInceptionE_2(2048)
188
+
189
+ state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
190
+ inception.load_state_dict(state_dict)
191
+ return inception
192
+
193
+
194
+ class FIDInceptionA(models.inception.InceptionA):
195
+ """InceptionA block patched for FID computation"""
196
+ def __init__(self, in_channels, pool_features):
197
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
198
+
199
+ def forward(self, x):
200
+ branch1x1 = self.branch1x1(x)
201
+
202
+ branch5x5 = self.branch5x5_1(x)
203
+ branch5x5 = self.branch5x5_2(branch5x5)
204
+
205
+ branch3x3dbl = self.branch3x3dbl_1(x)
206
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
207
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
208
+
209
+ # Patch: Tensorflow's average pool does not use the padded zeros in
210
+ # its average calculation
211
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
212
+ count_include_pad=False)
213
+ branch_pool = self.branch_pool(branch_pool)
214
+
215
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
216
+ return torch.cat(outputs, 1)
217
+
218
+
219
+ class FIDInceptionC(models.inception.InceptionC):
220
+ """InceptionC block patched for FID computation"""
221
+ def __init__(self, in_channels, channels_7x7):
222
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
223
+
224
+ def forward(self, x):
225
+ branch1x1 = self.branch1x1(x)
226
+
227
+ branch7x7 = self.branch7x7_1(x)
228
+ branch7x7 = self.branch7x7_2(branch7x7)
229
+ branch7x7 = self.branch7x7_3(branch7x7)
230
+
231
+ branch7x7dbl = self.branch7x7dbl_1(x)
232
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
233
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
234
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
235
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
236
+
237
+ # Patch: Tensorflow's average pool does not use the padded zeros in
238
+ # its average calculation
239
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
240
+ count_include_pad=False)
241
+ branch_pool = self.branch_pool(branch_pool)
242
+
243
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
244
+ return torch.cat(outputs, 1)
245
+
246
+
247
+ class FIDInceptionE_1(models.inception.InceptionE):
248
+ """First InceptionE block patched for FID computation"""
249
+ def __init__(self, in_channels):
250
+ super(FIDInceptionE_1, self).__init__(in_channels)
251
+
252
+ def forward(self, x):
253
+ branch1x1 = self.branch1x1(x)
254
+
255
+ branch3x3 = self.branch3x3_1(x)
256
+ branch3x3 = [
257
+ self.branch3x3_2a(branch3x3),
258
+ self.branch3x3_2b(branch3x3),
259
+ ]
260
+ branch3x3 = torch.cat(branch3x3, 1)
261
+
262
+ branch3x3dbl = self.branch3x3dbl_1(x)
263
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
264
+ branch3x3dbl = [
265
+ self.branch3x3dbl_3a(branch3x3dbl),
266
+ self.branch3x3dbl_3b(branch3x3dbl),
267
+ ]
268
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
269
+
270
+ # Patch: Tensorflow's average pool does not use the padded zeros in
271
+ # its average calculation
272
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
273
+ count_include_pad=False)
274
+ branch_pool = self.branch_pool(branch_pool)
275
+
276
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
277
+ return torch.cat(outputs, 1)
278
+
279
+
280
+ class FIDInceptionE_2(models.inception.InceptionE):
281
+ """Second InceptionE block patched for FID computation"""
282
+ def __init__(self, in_channels):
283
+ super(FIDInceptionE_2, self).__init__(in_channels)
284
+
285
+ def forward(self, x):
286
+ branch1x1 = self.branch1x1(x)
287
+
288
+ branch3x3 = self.branch3x3_1(x)
289
+ branch3x3 = [
290
+ self.branch3x3_2a(branch3x3),
291
+ self.branch3x3_2b(branch3x3),
292
+ ]
293
+ branch3x3 = torch.cat(branch3x3, 1)
294
+
295
+ branch3x3dbl = self.branch3x3dbl_1(x)
296
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
297
+ branch3x3dbl = [
298
+ self.branch3x3dbl_3a(branch3x3dbl),
299
+ self.branch3x3dbl_3b(branch3x3dbl),
300
+ ]
301
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
302
+
303
+ # Patch: The FID Inception model uses max pooling instead of average
304
+ # pooling. This is likely an error in this specific Inception
305
+ # implementation, as other Inception models use average pooling here
306
+ # (which matches the description in the paper).
307
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
308
+ branch_pool = self.branch_pool(branch_pool)
309
+
310
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
311
+ return torch.cat(outputs, 1)
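
A hedged usage sketch for the FID feature extractor above (not part of the commit; assumes the repository is importable and that the FID weights can be downloaded from `FID_WEIGHTS_URL` on first use).

```python
import torch
from models.inception import InceptionV3

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
net = InceptionV3([block_idx]).eval()

imgs = torch.rand(4, 3, 64, 64)   # values in (0, 1); resized to 299x299 internally
with torch.no_grad():
    feats = net(imgs)[0]          # one tensor per requested output block
print(feats.shape)                # torch.Size([4, 2048, 1, 1])
```
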
models/model.py ADDED
@@ -0,0 +1,894 @@
1
+ import torch.utils.data
2
+ from torch.nn import CTCLoss
3
+ from torch.nn.utils import clip_grad_norm_
4
+ import sys
5
+ import torchvision.models as models
6
+
7
+ from models.inception import InceptionV3
8
+ from models.transformer import *
9
+ from util.augmentations import OCRAugment
10
+ from util.misc import SmoothedValue
11
+ from util.text import get_generator, AugmentedGenerator
12
+ from .BigGAN_networks import *
13
+ from .OCR_network import *
14
+ from models.blocks import Conv2dBlock, ResBlocks
15
+ from util.util import loss_hinge_dis, loss_hinge_gen, make_one_hot
16
+
17
+ import models.config as config
18
+ from .positional_encodings import PositionalEncoding1D
19
+ from models.unifont_module import UnifontModule
20
+ from PIL import Image
21
+
22
+
23
+ def get_rgb(x):
24
+ R = 255 - int(int(x > 0.5) * 255 * (x - 0.5) / 0.5)
25
+ G = 0
26
+ B = 255 + int(int(x < 0.5) * 255 * (x - 0.5) / 0.5)
27
+ return R, G, B
28
+
29
+
30
+ def get_page_from_words(word_lists, MAX_IMG_WIDTH=800):
31
+ line_all = []
32
+ line_t = []
33
+
34
+ width_t = 0
35
+
36
+ for i in word_lists:
37
+
38
+ width_t = width_t + i.shape[1] + 16
39
+
40
+ if width_t > MAX_IMG_WIDTH:
41
+ line_all.append(np.concatenate(line_t, 1))
42
+
43
+ line_t = []
44
+
45
+ width_t = i.shape[1] + 16
46
+
47
+ line_t.append(i)
48
+ line_t.append(np.ones((i.shape[0], 16)))
49
+
50
+ if len(line_all) == 0:
51
+ line_all.append(np.concatenate(line_t, 1))
52
+
53
+ max_lin_widths = MAX_IMG_WIDTH # max([i.shape[1] for i in line_all])
54
+ gap_h = np.ones([16, max_lin_widths])
55
+
56
+ page_ = []
57
+
58
+ for l in line_all:
59
+ pad_ = np.ones([l.shape[0], max_lin_widths - l.shape[1]])
60
+
61
+ page_.append(np.concatenate([l, pad_], 1))
62
+ page_.append(gap_h)
63
+
64
+ page = np.concatenate(page_, 0)
65
+
66
+ return page * 255
67
+
68
+
69
+ class FCNDecoder(nn.Module):
70
+ def __init__(self, ups=3, n_res=2, dim=512, out_dim=1, res_norm='adain', activ='relu', pad_type='reflect'):
71
+ super(FCNDecoder, self).__init__()
72
+
73
+ self.model = []
74
+ self.model += [ResBlocks(n_res, dim, res_norm,
75
+ activ, pad_type=pad_type)]
76
+ for i in range(ups):
77
+ self.model += [nn.Upsample(scale_factor=2),
78
+ Conv2dBlock(dim, dim // 2, 5, 1, 2,
79
+ norm='in',
80
+ activation=activ,
81
+ pad_type=pad_type)]
82
+ dim //= 2
83
+ self.model += [Conv2dBlock(dim, out_dim, 7, 1, 3,
84
+ norm='none',
85
+ activation='tanh',
86
+ pad_type=pad_type)]
87
+ self.model = nn.Sequential(*self.model)
88
+
89
+ def forward(self, x):
90
+ y = self.model(x)
91
+
92
+ return y
93
+
94
+
95
+ class Generator(nn.Module):
96
+
97
+ def __init__(self, args):
98
+ super(Generator, self).__init__()
99
+ self.args = args
100
+ INP_CHANNEL = 1
101
+
102
+ encoder_layer = TransformerEncoderLayer(config.tn_hidden_dim, config.tn_nheads,
103
+ config.tn_dim_feedforward,
104
+ config.tn_dropout, "relu", True)
105
+ encoder_norm = nn.LayerNorm(config.tn_hidden_dim) if True else None
106
+ self.encoder = TransformerEncoder(encoder_layer, config.tn_enc_layers, encoder_norm)
107
+
108
+ decoder_layer = TransformerDecoderLayer(config.tn_hidden_dim, config.tn_nheads,
109
+ config.tn_dim_feedforward,
110
+ config.tn_dropout, "relu", True)
111
+ decoder_norm = nn.LayerNorm(config.tn_hidden_dim)
112
+ self.decoder = TransformerDecoder(decoder_layer, config.tn_dec_layers, decoder_norm,
113
+ return_intermediate=True)
114
+
115
+ self.Feat_Encoder = models.resnet18(weights='ResNet18_Weights.DEFAULT')
116
+ self.Feat_Encoder.conv1 = nn.Conv2d(INP_CHANNEL, 64, kernel_size=7, stride=2, padding=3, bias=False)
117
+ self.Feat_Encoder.fc = nn.Identity()
118
+ self.Feat_Encoder.avgpool = nn.Identity()
119
+
120
+ # self.query_embed = nn.Embedding(self.args.vocab_size, self.args.tn_hidden_dim)
121
+ self.query_embed = UnifontModule(
122
+ config.tn_dim_feedforward,
123
+ self.args.alphabet + self.args.special_alphabet,
124
+ input_type=self.args.query_input,
125
+ device=self.args.device
126
+ )
127
+
128
+ self.pos_encoder = PositionalEncoding1D(config.tn_hidden_dim)
129
+
130
+ self.linear_q = nn.Linear(config.tn_dim_feedforward, config.tn_dim_feedforward * 8)
131
+
132
+ self.DEC = FCNDecoder(res_norm='in', dim=config.tn_hidden_dim)
133
+
134
+ self.noise = torch.distributions.Normal(loc=torch.tensor([0.]), scale=torch.tensor([1.0]))
135
+
136
+ def evaluate(self, style_images, queries):
137
+ style = self.compute_style(style_images)
138
+
139
+ results = []
140
+
141
+ for i in range(queries.shape[1]):
142
+ query = queries[:, i, :]
143
+ h = self.generate(style, query)
144
+
145
+ results.append(h.detach())
146
+
147
+ return results
148
+
149
+ def compute_style(self, style_images):
150
+ B, N, R, C = style_images.shape
151
+ FEAT_ST = self.Feat_Encoder(style_images.view(B * N, 1, R, C))
152
+ FEAT_ST = FEAT_ST.view(B, 512, 1, -1)
153
+ FEAT_ST_ENC = FEAT_ST.flatten(2).permute(2, 0, 1)
154
+ memory = self.encoder(FEAT_ST_ENC)
155
+ return memory
156
+
157
+ def generate(self, style_vector, query):
158
+ query_embed = self.query_embed(query).permute(1, 0, 2)
159
+
160
+ tgt = torch.zeros_like(query_embed)
161
+ hs = self.decoder(tgt, style_vector, query_pos=query_embed)
162
+
163
+ h = hs.transpose(1, 2)[-1]
164
+
165
+ if self.args.add_noise:
166
+ h = h + self.noise.sample(h.size()).squeeze(-1).to(self.args.device)
167
+
168
+ h = self.linear_q(h)
169
+ h = h.contiguous()
170
+
171
+ h = h.view(h.size(0), h.shape[1] * 2, 4, -1)
172
+ h = h.permute(0, 3, 2, 1)
173
+
174
+ h = self.DEC(h)
175
+
176
+ return h
177
+
178
+ def forward(self, style_images, query):
179
+ enc_attn_weights, dec_attn_weights = [], []
180
+
181
+ self.hooks = [
182
+
183
+ self.encoder.layers[-1].self_attn.register_forward_hook(
184
+ lambda self, input, output: enc_attn_weights.append(output[1])
185
+ ),
186
+ self.decoder.layers[-1].multihead_attn.register_forward_hook(
187
+ lambda self, input, output: dec_attn_weights.append(output[1])
188
+ ),
189
+ ]
190
+
191
+ style = self.compute_style(style_images)
192
+
193
+ h = self.generate(style, query)
194
+
195
+ self.dec_attn_weights = dec_attn_weights[-1].detach()
196
+ self.enc_attn_weights = enc_attn_weights[-1].detach()
197
+
198
+ for hook in self.hooks:
199
+ hook.remove()
200
+
201
+ return h, style
202
+
203
+
204
+ class VATr(nn.Module):
205
+
206
+ def __init__(self, args):
207
+ super(VATr, self).__init__()
208
+ self.args = args
209
+ self.args.vocab_size = len(args.alphabet)
210
+
211
+ self.epsilon = 1e-7
212
+ self.netG = Generator(self.args).to(self.args.device)
213
+ self.netD = Discriminator(
214
+ resolution=self.args.resolution, crop_size=args.d_crop_size,
215
+ ).to(self.args.device)
216
+
217
+ self.netW = WDiscriminator(resolution=self.args.resolution, n_classes=self.args.vocab_size, output_dim=self.args.num_writers)
218
+ self.netW = self.netW.to(self.args.device)
219
+ self.netconverter = strLabelConverter(self.args.alphabet + self.args.special_alphabet)
220
+
221
+ self.netOCR = CRNN(self.args).to(self.args.device)
222
+
223
+ self.ocr_augmenter = OCRAugment(prob=0.5, no=3)
224
+ self.OCR_criterion = CTCLoss(zero_infinity=True, reduction='none')
225
+
226
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
227
+ self.inception = InceptionV3([block_idx]).to(self.args.device)
228
+
229
+ self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
230
+ lr=self.args.g_lr, betas=(0.0, 0.999), weight_decay=0, eps=1e-8)
231
+
232
+ self.optimizer_OCR = torch.optim.Adam(self.netOCR.parameters(),
233
+ lr=self.args.ocr_lr, betas=(0.0, 0.999), weight_decay=0, eps=1e-8)
234
+
235
+ self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
236
+ lr=self.args.d_lr, betas=(0.0, 0.999), weight_decay=0, eps=1e-8)
237
+
238
+ self.optimizer_wl = torch.optim.Adam(self.netW.parameters(),
239
+ lr=self.args.w_lr, betas=(0.0, 0.999), weight_decay=0, eps=1e-8)
240
+
241
+ self.optimizers = [self.optimizer_G, self.optimizer_OCR, self.optimizer_D, self.optimizer_wl]
242
+
243
+ self.optimizer_G.zero_grad()
244
+ self.optimizer_OCR.zero_grad()
245
+ self.optimizer_D.zero_grad()
246
+ self.optimizer_wl.zero_grad()
247
+
248
+ self.loss_G = 0
249
+ self.loss_D = 0
250
+ self.loss_Dfake = 0
251
+ self.loss_Dreal = 0
252
+ self.loss_OCR_fake = 0
253
+ self.loss_OCR_real = 0
254
+ self.loss_w_fake = 0
255
+ self.loss_w_real = 0
256
+ self.Lcycle = 0
257
+ self.d_acc = SmoothedValue()
258
+
259
+ self.word_generator = get_generator(args)
260
+
261
+ self.epoch = 0
262
+
263
+ with open('mytext.txt', 'r', encoding='utf-8') as f:
264
+ self.text = f.read()
265
+ self.text = self.text.replace('\n', ' ')
267
+ self.text = ''.join(c for c in self.text if c in (self.args.alphabet + self.args.special_alphabet)) # just to avoid problems with the font dataset
268
+ self.text = [word.encode() for word in self.text.split()] # [:args.num_examples]
269
+
270
+ self.eval_text_encode, self.eval_len_text, self.eval_encode_pos = self.netconverter.encode(self.text)
271
+ self.eval_text_encode = self.eval_text_encode.to(self.args.device).repeat(self.args.batch_size, 1, 1)
272
+
273
+ self.rv_sample_size = 64 * 4
274
+ self.last_fakes = []
275
+
276
+ def update_last_fakes(self, fakes):
277
+ for fake in fakes:
278
+ self.last_fakes.append(fake)
279
+ self.last_fakes = self.last_fakes[-self.rv_sample_size:]
280
+
281
+ def update_acc(self, pred_real, pred_fake):
282
+ correct = (pred_real >= 0.5).float().sum() + (pred_fake < 0.5).float().sum()
283
+ self.d_acc.update(correct / (len(pred_real) + len(pred_fake)))
284
+
285
+ def set_text_aug_strength(self, strength):
286
+ if not isinstance(self.word_generator, AugmentedGenerator):
287
+ print("WARNING: Text generator is not augmented, strength cannot be set")
288
+ else:
289
+ self.word_generator.set_strength(strength)
290
+
291
+ def get_text_aug_strength(self):
292
+ if isinstance(self.word_generator, AugmentedGenerator):
293
+ return self.word_generator.strength
294
+ else:
295
+ return 0.0
296
+
297
+ def update_parameters(self, epoch: int):
298
+ self.epoch = epoch
299
+ self.netD.update_parameters(epoch)
300
+ self.netW.update_parameters(epoch)
301
+
302
+ def get_text_sample(self, size: int) -> list:
303
+ return [self.word_generator.generate() for _ in range(size)]
304
+
305
+ def _generate_fakes(self, ST, eval_text_encode=None, eval_len_text=None):
306
+ if eval_text_encode is None:
307
+ eval_text_encode = self.eval_text_encode
308
+ if eval_len_text is None:
309
+ eval_len_text = self.eval_len_text
310
+
311
+ self.fakes = self.netG.evaluate(ST, eval_text_encode)
312
+
313
+ np_fakes = []
314
+ for batch_idx in range(self.fakes[0].shape[0]):
315
+ for idx, fake in enumerate(self.fakes):
316
+ fake = fake[batch_idx, 0, :, :eval_len_text[idx] * self.args.resolution]
317
+ fake = (fake + 1) / 2
318
+ np_fakes.append(fake.cpu().numpy())
319
+ return np_fakes
320
+
321
+ def _generate_page(self, ST, SLEN, eval_text_encode=None, eval_len_text=None, eval_encode_pos=None, lwidth=260, rwidth=980):
322
+ # ST -> Style?
323
+
324
+ if eval_text_encode is None:
325
+ eval_text_encode = self.eval_text_encode
326
+ if eval_len_text is None:
327
+ eval_len_text = self.eval_len_text
328
+ if eval_encode_pos is None:
329
+ eval_encode_pos = self.eval_encode_pos
330
+
331
+ text_encode, text_len, _ = self.netconverter.encode(self.args.special_alphabet)
332
+ symbols = self.netG.query_embed.symbols[text_encode].reshape(-1, 16, 16).cpu().numpy()
333
+ imgs = [Image.fromarray(s).resize((32, 32), resample=0) for s in symbols]
334
+ special_examples = 1 - np.concatenate([np.array(i) for i in imgs], axis=-1)
335
+
336
+ self.fakes = self.netG.evaluate(ST, eval_text_encode)
337
+
338
+ page1s = []
339
+ page2s = []
340
+
341
+ for batch_idx in range(ST.shape[0]):
342
+
343
+ word_t = []
344
+ word_l = []
345
+
346
+ gap = np.ones([self.args.img_height, 16])
347
+
348
+ line_wids = []
349
+
350
+ for idx, fake_ in enumerate(self.fakes):
351
+
352
+ word_t.append((fake_[batch_idx, 0, :, :eval_len_text[idx] * self.args.resolution].cpu().numpy() + 1) / 2)
353
+
354
+ word_t.append(gap)
355
+
356
+ if sum(t.shape[-1] for t in word_t) >= rwidth or idx == len(self.fakes) - 1 or (len(self.fakes) - len(self.args.special_alphabet) - 1) == idx:
357
+ line_ = np.concatenate(word_t, -1)
358
+
359
+ word_l.append(line_)
360
+ line_wids.append(line_.shape[1])
361
+
362
+ word_t = []
363
+
364
+ # add the examples from the UnifontModules
365
+ word_l.append(special_examples)
366
+ line_wids.append(special_examples.shape[1])
367
+
368
+ gap_h = np.ones([16, max(line_wids)])
369
+
370
+ page_ = []
371
+
372
+ for l in word_l:
373
+ pad_ = np.ones([self.args.img_height, max(line_wids) - l.shape[1]])
374
+
375
+ page_.append(np.concatenate([l, pad_], 1))
376
+ page_.append(gap_h)
377
+
378
+ page1 = np.concatenate(page_, 0)
379
+
380
+ word_t = []
381
+ word_l = []
382
+
383
+
384
+ line_wids = []
385
+
386
+ sdata_ = [i.unsqueeze(1) for i in torch.unbind(ST, 1)]
387
+ gap = np.ones([sdata_[0].shape[-2], 16])
388
+
389
+ for idx, st in enumerate((sdata_)):
390
+
391
+ word_t.append((st[batch_idx, 0, :, :int(SLEN.cpu().numpy()[batch_idx][idx])].cpu().numpy() + 1) / 2)
392
+ # word_t.append((st[batch_idx, 0, :, :].cpu().numpy() + 1) / 2)
393
+
394
+ word_t.append(gap)
395
+
396
+ if sum(t.shape[-1] for t in word_t) >= lwidth or idx == len(sdata_) - 1:
397
+ line_ = np.concatenate(word_t, -1)
398
+
399
+ word_l.append(line_)
400
+ line_wids.append(line_.shape[1])
401
+
402
+ word_t = []
403
+
404
+ gap_h = np.ones([16, max(line_wids)])
405
+
406
+ page_ = []
407
+
408
+ for l in word_l:
409
+ pad_ = np.ones([sdata_[0].shape[-2], max(line_wids) - l.shape[1]])
410
+
411
+ page_.append(np.concatenate([l, pad_], 1))
412
+ page_.append(gap_h)
413
+
414
+ page2 = np.concatenate(page_, 0)
415
+
416
+ merge_w_size = max(page1.shape[0], page2.shape[0])
417
+
418
+ if page1.shape[0] != merge_w_size:
419
+ page1 = np.concatenate([page1, np.ones([merge_w_size - page1.shape[0], page1.shape[1]])], 0)
420
+
421
+ if page2.shape[0] != merge_w_size:
422
+ page2 = np.concatenate([page2, np.ones([merge_w_size - page2.shape[0], page2.shape[1]])], 0)
423
+
424
+ page1s.append(page1)
425
+ page2s.append(page2)
426
+
427
+ # page = np.concatenate([page2, page1], 1)
428
+
429
+ page1s_ = np.concatenate(page1s, 0)
430
+ max_wid = max([i.shape[1] for i in page2s])
431
+ padded_page2s = []
432
+
433
+ for para in page2s:
434
+ padded_page2s.append(np.concatenate([para, np.ones([para.shape[0], max_wid - para.shape[1]])], 1))
435
+
436
+ padded_page2s_ = np.concatenate(padded_page2s, 0)
437
+
438
+ return np.concatenate([padded_page2s_, page1s_], 1)
439
+
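+ # Sketch of _generate_page() above: for each batch element the generated words are packed
+ # into lines of about rwidth pixels (page1, with an extra row showing the special-alphabet
+ # archetypes) and the style samples in ST into lines of about lwidth pixels (page2); both
+ # pages are padded to a common height and returned side by side, style samples on the left
+ # and generated text on the right.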
440
+ def get_current_losses(self):
441
+
442
+ losses = {}
443
+
444
+ losses['G'] = self.loss_G
445
+ losses['D'] = self.loss_D
446
+ losses['Dfake'] = self.loss_Dfake
447
+ losses['Dreal'] = self.loss_Dreal
448
+ losses['OCR_fake'] = self.loss_OCR_fake
449
+ losses['OCR_real'] = self.loss_OCR_real
450
+ losses['w_fake'] = self.loss_w_fake
451
+ losses['w_real'] = self.loss_w_real
452
+ losses['cycle'] = self.Lcycle
453
+
454
+ return losses
455
+
456
+ def _set_input(self, input):
457
+ self.input = input
458
+
459
+ self.real = self.input['img'].to(self.args.device)
460
+ self.label = self.input['label']
461
+
462
+ self.set_ocr_data(self.input['img'], self.input['label'])
463
+
464
+ self.sdata = self.input['simg'].to(self.args.device)
465
+ self.slabels = self.input['slabels']
466
+
467
+ self.ST_LEN = self.input['swids']
468
+
469
+ def set_requires_grad(self, nets, requires_grad=False):
470
+ """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
471
+ Parameters:
472
+ nets (network list) -- a list of networks
473
+ requires_grad (bool) -- whether the networks require gradients or not
474
+ """
475
+ if not isinstance(nets, list):
476
+ nets = [nets]
477
+ for net in nets:
478
+ if net is not None:
479
+ for param in net.parameters():
480
+ param.requires_grad = requires_grad
481
+
482
+ def forward(self):
483
+ self.text_encode, self.len_text, self.encode_pos = self.netconverter.encode(self.label)
484
+ self.text_encode = self.text_encode.to(self.args.device).detach()
485
+ self.len_text = self.len_text.detach()
486
+
487
+ self.words = [self.word_generator.generate().encode('utf-8') for _ in range(self.args.batch_size)]
488
+ self.text_encode_fake, self.len_text_fake, self.encode_pos_fake = self.netconverter.encode(self.words)
489
+ self.text_encode_fake = self.text_encode_fake.to(self.args.device)
490
+ self.one_hot_fake = make_one_hot(self.text_encode_fake, self.len_text_fake, self.args.vocab_size).to(
491
+ self.args.device)
492
+
493
+ self.fake, self.style = self.netG(self.sdata, self.text_encode_fake)
494
+
495
+ self.update_last_fakes(self.fake)
496
+
497
+ def pad_width(self, t, new_width):
498
+ result = torch.ones((t.size(0), t.size(1), t.size(2), new_width), device=t.device)
499
+ result[:,:,:,:t.size(-1)] = t
500
+
501
+ return result
502
+
503
+ def compute_real_ocr_loss(self, ocr_network = None):
504
+ network = ocr_network if ocr_network is not None else self.netOCR
505
+ real_input = self.ocr_images
506
+ input_images = real_input
507
+ input_labels = self.ocr_labels
508
+
509
+ input_images = input_images.detach()
510
+
511
+ if self.ocr_augmenter is not None:
512
+ input_images = self.ocr_augmenter(input_images)
513
+
514
+ pred_real = network(input_images)
515
+ preds_size = torch.IntTensor([pred_real.size(0)] * len(input_labels)).detach()
516
+ text_encode, len_text, _ = self.netconverter.encode(input_labels)
517
+
518
+ loss = self.OCR_criterion(pred_real, text_encode.detach(), preds_size, len_text.detach())
519
+
520
+ return torch.mean(loss[~torch.isnan(loss)])
521
+
522
+ def compute_fake_ocr_loss(self, ocr_network = None):
523
+ network = ocr_network if ocr_network is not None else self.netOCR
524
+
525
+ pred_fake_OCR = network(self.fake)
526
+ preds_size = torch.IntTensor([pred_fake_OCR.size(0)] * self.args.batch_size).detach()
527
+ loss_OCR_fake = self.OCR_criterion(pred_fake_OCR, self.text_encode_fake.detach(), preds_size,
528
+ self.len_text_fake.detach())
529
+ return torch.mean(loss_OCR_fake[~torch.isnan(loss_OCR_fake)])
530
+
531
+ def set_ocr_data(self, images, labels):
532
+ self.ocr_images = images.to(self.args.device)
533
+ self.ocr_labels = labels
534
+
535
+ def backward_D_OCR(self):
536
+ self.real.__repr__()
537
+ self.fake.__repr__()
538
+ pred_real = self.netD(self.real.detach())
539
+ pred_fake = self.netD(**{'x': self.fake.detach()})
540
+
541
+ self.update_acc(pred_real, pred_fake)
542
+
543
+ self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(),
544
+ self.len_text.detach(), True)
545
+
546
+ self.loss_D = self.loss_Dreal + self.loss_Dfake
547
+
548
+ if not self.args.no_ocr_loss:
549
+ self.loss_OCR_real = self.compute_real_ocr_loss()
550
+ loss_total = self.loss_D + self.loss_OCR_real
551
+ else:
552
+ loss_total = self.loss_D
553
+
554
+ # backward
555
+ loss_total.backward()
556
+ if not self.args.no_ocr_loss:
557
+ self.clean_grad(self.netOCR.parameters())
558
+
559
+ return loss_total
560
+
561
+ def clean_grad(self, params):
562
+ for param in params:
563
+ param.grad[param.grad != param.grad] = 0
564
+ param.grad[torch.isnan(param.grad)] = 0
565
+ param.grad[torch.isinf(param.grad)] = 0
566
+
567
+ def backward_D_WL(self):
568
+ # Real
569
+ pred_real = self.netD(self.real.detach())
570
+
571
+ pred_fake = self.netD(**{'x': self.fake.detach()})
572
+
573
+ self.update_acc(pred_real, pred_fake)
574
+
575
+ self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(),
576
+ self.len_text.detach(), True)
577
+
578
+ self.loss_D = self.loss_Dreal + self.loss_Dfake
579
+
580
+ if not self.args.no_writer_loss:
581
+ self.loss_w_real = self.netW(self.real.detach(), self.input['wcl'].to(self.args.device)).mean()
582
+ # total loss
583
+ loss_total = self.loss_D + self.loss_w_real * self.args.writer_loss_weight
584
+ else:
585
+ loss_total = self.loss_D
586
+
587
+ # backward
588
+ loss_total.backward()
589
+
590
+ return loss_total
591
+
592
+ def optimize_D_WL(self):
593
+ self.forward()
594
+ self.set_requires_grad([self.netD], True)
595
+ self.set_requires_grad([self.netOCR], False)
596
+ self.set_requires_grad([self.netW], True)
597
+ self.set_requires_grad([self.netW], True)
598
+
599
+ self.optimizer_D.zero_grad()
600
+ self.optimizer_wl.zero_grad()
601
+
602
+ self.backward_D_WL()
603
+
604
+ def optimize_D_WL_step(self):
605
+ self.optimizer_D.step()
606
+ self.optimizer_wl.step()
607
+ self.optimizer_D.zero_grad()
608
+ self.optimizer_wl.zero_grad()
609
+
610
+ def compute_cycle_loss(self):
611
+ fake_input = torch.ones_like(self.sdata)
612
+ width = min(self.sdata.size(-1), self.fake.size(-1))
613
+ fake_input[:, :, :, :width] = self.fake.repeat(1, 15, 1, 1)[:, :, :, :width]  # tile the single-channel fake to fill the 15 style-sample slots of sdata
614
+ with torch.no_grad():
615
+ fake_style = self.netG.compute_style(fake_input)
616
+
617
+ return torch.sum(torch.abs(self.style.detach() - fake_style), dim=1).mean()
618
+
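+ # Sketch of compute_cycle_loss() above: the generated image is tiled into the layout of a
+ # style-sample batch, re-encoded with netG.compute_style under no_grad, and the loss is the
+ # mean absolute difference between the original (detached) style vector and the re-encoded
+ # one.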
619
+ def backward_G_only(self):
620
+ self.gb_alpha = 0.7
621
+ if self.args.is_cycle:
622
+ self.Lcycle = self.compute_cycle_loss()
623
+
624
+ self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake}), self.len_text_fake.detach(), True).mean()
625
+
626
+ compute_ocr = not self.args.no_ocr_loss
627
+
628
+ if compute_ocr:
629
+ self.loss_OCR_fake = self.compute_fake_ocr_loss()
630
+
631
+ self.loss_G = self.loss_G + self.Lcycle
632
+
633
+ if compute_ocr:
634
+ self.loss_T = self.loss_G + self.loss_OCR_fake
635
+ else:
636
+ self.loss_T = self.loss_G
637
+
638
+ if compute_ocr:
639
+ grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, retain_graph=True)[0]
640
+ self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
641
+
642
+ grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, retain_graph=True)[0]
643
+ self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
644
+
645
+ self.loss_T.backward(retain_graph=True)
646
+
647
+ if compute_ocr:
648
+ grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=True, retain_graph=True)[0]
649
+ grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0]
650
+ a = self.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon + torch.std(grad_fake_OCR))
651
+ self.loss_OCR_fake = a.detach() * self.loss_OCR_fake
652
+ self.loss_T = self.loss_G + self.loss_OCR_fake
653
+ else:
654
+ grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0]
655
+ a = 1
656
+ self.loss_T = self.loss_G
657
+
658
+ if a is None:
659
+ print(self.loss_OCR_fake, self.loss_G, torch.std(grad_fake_adv))
660
+ if a > 1000 or a < 0.0001:
661
+ print(f'WARNING: alpha > 1000 or alpha < 0.0001 - alpha={a.item()}')
662
+
663
+ self.loss_T.backward(retain_graph=True)
664
+ if compute_ocr:
665
+ grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=False, retain_graph=True)[0]
666
+ self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
667
+ grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=False, retain_graph=True)[0]
668
+ self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
669
+
670
+ with torch.no_grad():
671
+ self.loss_T.backward()
672
+ if compute_ocr:
673
+ if any(torch.isnan(torch.unsqueeze(self.loss_OCR_fake, dim=0))) or torch.isnan(self.loss_G):
674
+ print('loss OCR fake: ', self.loss_OCR_fake, ' loss_G: ', self.loss_G, ' words: ', self.words)
675
+ sys.exit()
676
+
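+ # Gradient-balancing sketch for backward_G_only() above: the OCR term is rescaled with
+ #     alpha = gb_alpha * std(grad_adv) / (epsilon + std(grad_OCR)),
+ # where the gradients are taken w.r.t. the fake image, so that the recognizer loss
+ # contributes with a spread comparable to the adversarial loss; the generator objective is
+ # then loss_G (+ cycle term) + alpha * loss_OCR_fake.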
677
+ def backward_G_WL(self):
678
+ self.gb_alpha = 0.7
679
+ if self.args.is_cycle:
680
+ self.Lcycle = self.compute_cycle_loss()
681
+
682
+ self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake}), self.len_text_fake.detach(), True).mean()
683
+
684
+ if not self.args.no_writer_loss:
685
+ self.loss_w_fake = self.netW(self.fake, self.input['wcl'].to(self.args.device)).mean()
686
+
687
+ self.loss_G = self.loss_G + self.Lcycle
688
+
689
+ if not self.args.no_writer_loss:
690
+ self.loss_T = self.loss_G + self.loss_w_fake * self.args.writer_loss_weight
691
+ else:
692
+ self.loss_T = self.loss_G
693
+
694
+ self.loss_T.backward(retain_graph=True)
695
+
696
+ if not self.args.no_writer_loss:
697
+ grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=True, retain_graph=True)[0]
698
+ grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0]
699
+ a = self.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon + torch.std(grad_fake_WL))
700
+ self.loss_w_fake = a.detach() * self.loss_w_fake
701
+ self.loss_T = self.loss_G + self.loss_w_fake
702
+ else:
703
+ grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0]
704
+ a = 1
705
+ self.loss_T = self.loss_G
706
+
707
+ if a is None:
708
+ print(self.loss_w_fake, self.loss_G, torch.std(grad_fake_adv))
709
+ if a > 1000 or a < 0.0001:
710
+ print(f'WARNING: alpha > 1000 or alpha < 0.0001 - alpha={a.item()}')
711
+
712
+ self.loss_T.backward(retain_graph=True)
713
+
714
+ if not self.args.no_writer_loss:
715
+ grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=False, retain_graph=True)[0]
716
+ self.loss_grad_fake_WL = 10 ** 6 * torch.mean(grad_fake_WL ** 2)
717
+ grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=False, retain_graph=True)[0]
718
+ self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
719
+
720
+ with torch.no_grad():
721
+ self.loss_T.backward()
722
+
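+ # Note: backward_G() below references self.opt, self.z and self.wcl, which are not set in
+ # __init__; it appears to be an older code path, while backward_G_only() and
+ # backward_G_WL() above are the ones driven by optimize_G_only() / optimize_G_WL().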
723
+ def backward_G(self):
724
+ self.opt.gb_alpha = 0.7
725
+ self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake, 'z': self.z}), self.len_text_fake.detach(),
726
+ self.opt.mask_loss)
727
+ # OCR loss on real data
728
+ compute_ocr = not self.args.no_ocr_loss
729
+
730
+ if compute_ocr:
731
+ self.loss_OCR_fake = self.compute_fake_ocr_loss()
732
+ else:
733
+ self.loss_OCR_fake = 0.0
734
+
735
+ self.loss_w_fake = self.netW(self.fake, self.wcl)
736
+ # self.loss_OCR_fake = self.loss_OCR_fake + self.loss_w_fake
737
+ # total loss
738
+
739
+ # l1 = self.params[0]*self.loss_G
740
+ # l2 = self.params[0]*self.loss_OCR_fake
741
+ # l3 = self.params[0]*self.loss_w_fake
742
+ self.loss_G_ = 10 * self.loss_G + self.loss_w_fake
743
+ self.loss_T = self.loss_G_ + self.loss_OCR_fake
744
+
745
+ grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, retain_graph=True)[0]
746
+
747
+ self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
748
+ grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, retain_graph=True)[0]
749
+ self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
750
+
751
+ if not False:
752
+
753
+ self.loss_T.backward(retain_graph=True)
754
+
755
+ grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=True, retain_graph=True)[0]
756
+ grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, create_graph=True, retain_graph=True)[0]
757
+ # grad_fake_wl = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=True, retain_graph=True)[0]
758
+
759
+ a = self.opt.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon + torch.std(grad_fake_OCR))
760
+
761
+ # a0 = self.opt.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_wl))
762
+
763
+ if a is None:
764
+ print(self.loss_OCR_fake, self.loss_G_, torch.std(grad_fake_adv), torch.std(grad_fake_OCR))
765
+ if a > 1000 or a < 0.0001:
766
+ print(f'WARNING: alpha > 1000 or alpha < 0.0001 - alpha={a.item()}')
767
+ b = self.opt.gb_alpha * (torch.mean(grad_fake_adv) -
768
+ torch.div(torch.std(grad_fake_adv), self.epsilon + torch.std(grad_fake_OCR)) *
769
+ torch.mean(grad_fake_OCR))
770
+ # self.loss_OCR_fake = a.detach() * self.loss_OCR_fake + b.detach() * torch.sum(self.fake)
771
+ self.loss_OCR_fake = a.detach() * self.loss_OCR_fake
772
+ # self.loss_w_fake = a0.detach() * self.loss_w_fake
773
+
774
+ self.loss_T = (1 - 1 * self.opt.onlyOCR) * self.loss_G_ + self.loss_OCR_fake # + self.loss_w_fake
775
+ self.loss_T.backward(retain_graph=True)
776
+ grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=False, retain_graph=True)[0]
777
+ grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, create_graph=False, retain_graph=True)[0]
778
+ self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
779
+ self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
780
+ with torch.no_grad():
781
+ self.loss_T.backward()
782
+ else:
783
+ self.loss_T.backward()
784
+
785
+ if self.opt.clip_grad > 0:
786
+ clip_grad_norm_(self.netG.parameters(), self.opt.clip_grad)
787
+ if any(torch.isnan(torch.unsqueeze(self.loss_OCR_fake, dim=0))) or torch.isnan(self.loss_G_):
788
+ print('loss OCR fake: ', self.loss_OCR_fake, ' loss_G: ', self.loss_G, ' words: ', self.words)
789
+ sys.exit()
790
+
791
+ def optimize_D_OCR(self):
792
+ self.forward()
793
+ self.set_requires_grad([self.netD], True)
794
+ self.set_requires_grad([self.netOCR], True)
795
+ self.optimizer_D.zero_grad()
796
+ # if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']:
797
+ self.optimizer_OCR.zero_grad()
798
+ self.backward_D_OCR()
799
+
800
+ def optimize_D_OCR_step(self):
801
+ self.optimizer_D.step()
802
+
803
+ self.optimizer_OCR.step()
804
+ self.optimizer_D.zero_grad()
805
+ self.optimizer_OCR.zero_grad()
806
+
807
+ def optimize_G_WL(self):
808
+ self.forward()
809
+ self.set_requires_grad([self.netD], False)
810
+ self.set_requires_grad([self.netOCR], False)
811
+ self.set_requires_grad([self.netW], False)
812
+ self.backward_G_WL()
813
+
814
+ def optimize_G_only(self):
815
+ self.forward()
816
+ self.set_requires_grad([self.netD], False)
817
+ self.set_requires_grad([self.netOCR], False)
818
+ self.set_requires_grad([self.netW], False)
819
+ self.backward_G_only()
820
+
821
+ def optimize_G_step(self):
822
+ self.optimizer_G.step()
823
+ self.optimizer_G.zero_grad()
824
+
825
+ def save_networks(self, epoch, save_dir):
826
+ """Save all the networks to the disk.
827
+
828
+ Parameters:
829
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
830
+ """
831
+ for name in self.model_names:
832
+ if isinstance(name, str):
833
+ save_filename = '%s_net_%s.pth' % (epoch, name)
834
+ save_path = os.path.join(save_dir, save_filename)
835
+ net = getattr(self, 'net' + name)
836
+
837
+ if len(self.gpu_ids) > 0 and torch.cuda.is_available():
838
+ # torch.save(net.module.cpu().state_dict(), save_path)
839
+ if len(self.gpu_ids) > 1:
840
+ torch.save(net.module.cpu().state_dict(), save_path)
841
+ else:
842
+ torch.save(net.cpu().state_dict(), save_path)
843
+ net.cuda(self.gpu_ids[0])
844
+ else:
845
+ torch.save(net.cpu().state_dict(), save_path)
846
+
847
+ def compute_d_scores(self, data_loader: torch.utils.data.DataLoader, amount: int = None):
848
+ scores = []
849
+ words = []
850
+ amount = len(data_loader) if amount is None else amount // data_loader.batch_size
851
+
852
+ with torch.no_grad():
853
+ for i in range(amount):
854
+ data = next(iter(data_loader))
855
+ words.extend([d.decode() for d in data['label']])
856
+ scores.extend(list(self.netD(data['img'].to(self.args.device)).squeeze().detach().cpu().numpy()))
857
+
858
+ return scores, words
859
+
860
+ def compute_d_scores_fake(self, data_loader: torch.utils.data.DataLoader, amount: int = None):
861
+ scores = []
862
+ words = []
863
+ amount = len(data_loader) if amount is None else amount // data_loader.batch_size
864
+
865
+ with torch.no_grad():
866
+ for i in range(amount):
867
+ data = next(iter(data_loader))
868
+ to_generate = [self.word_generator.generate().encode('utf-8') for _ in range(data_loader.batch_size)]
869
+ text_encode_fake, len_text_fake, encode_pos_fake = self.netconverter.encode(to_generate)
870
+ fake, _ = self.netG(data['simg'].to(self.args.device), text_encode_fake.to(self.args.device))
871
+
872
+ words.extend([d.decode() for d in to_generate])
873
+ scores.extend(list(self.netD(fake).squeeze().detach().cpu().numpy()))
874
+
875
+ return scores, words
876
+
877
+ def compute_d_stats(self, train_loader: torch.utils.data.DataLoader, val_loader: torch.utils.data.DataLoader):
878
+ train_values = []
879
+ val_values = []
880
+ fake_values = []
881
+ with torch.no_grad():
882
+ for i in range(self.rv_sample_size // train_loader.batch_size):
883
+ data = next(iter(train_loader))
884
+ train_values.append(self.netD(data['img'].to(self.args.device)).squeeze().detach().cpu().numpy())
885
+
886
+ for i in range(self.rv_sample_size // val_loader.batch_size):
887
+ data = next(iter(val_loader))
888
+ val_values.append(self.netD(data['img'].to(self.args.device)).squeeze().detach().cpu().numpy())
889
+
890
+ for i in range(self.rv_sample_size):
891
+ data = self.last_fakes[i]
892
+ fake_values.append(self.netD(data.unsqueeze(0)).squeeze().detach().cpu().numpy())
893
+
894
+ return np.mean(train_values), np.mean(val_values), np.mean(fake_values)
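+ # Note on the helpers above: compute_d_scores / compute_d_scores_fake collect raw
+ # discriminator scores (with the corresponding words) for real and generated samples, and
+ # compute_d_stats returns the mean discriminator score on training images, validation
+ # images and the most recent fakes, e.g. to monitor discriminator over-fitting.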
models/networks.py ADDED
@@ -0,0 +1,98 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import init
4
+ import functools
5
+ from torch.optim import lr_scheduler
6
+ from util.util import to_device, load_network
7
+
8
+ ###############################################################################
9
+ # Helper Functions
10
+ ###############################################################################
11
+
12
+
13
+ def init_weights(net, init_type='normal', init_gain=0.02):
14
+ """Initialize network weights.
15
+
16
+ Parameters:
17
+ net (network) -- network to be initialized
18
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
19
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
20
+
21
+ We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
22
+ work better for some applications. Feel free to try yourself.
23
+ """
24
+ def init_func(m): # define the initialization function
25
+ classname = m.__class__.__name__
26
+ if (isinstance(m, nn.Conv2d)
27
+ or isinstance(m, nn.Linear)
28
+ or isinstance(m, nn.Embedding)):
29
+ # if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
30
+ if init_type == 'N02':
31
+ init.normal_(m.weight.data, 0.0, init_gain)
32
+ elif init_type in ['glorot', 'xavier']:
33
+ init.xavier_normal_(m.weight.data, gain=init_gain)
34
+ elif init_type == 'kaiming':
35
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
36
+ elif init_type == 'ortho':
37
+ init.orthogonal_(m.weight.data, gain=init_gain)
38
+ else:
39
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
40
+ # if hasattr(m, 'bias') and m.bias is not None:
41
+ # init.constant_(m.bias.data, 0.0)
42
+ # elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
43
+ # init.normal_(m.weight.data, 1.0, init_gain)
44
+ # init.constant_(m.bias.data, 0.0)
45
+ if init_type in ['N02', 'glorot', 'xavier', 'kaiming', 'ortho']:
46
+ # print('initialize network with %s' % init_type)
47
+ net.apply(init_func) # apply the initialization function <init_func>
48
+ else:
49
+ # print('loading the model from %s' % init_type)
50
+ net = load_network(net, init_type, 'latest')
51
+ return net
52
+
53
+ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
54
+ """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
55
+ Parameters:
56
+ net (network) -- the network to be initialized
57
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
58
+ gain (float) -- scaling factor for normal, xavier and orthogonal.
59
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
60
+
61
+ Return an initialized network.
62
+ """
63
+ if len(gpu_ids) > 0:
64
+ assert(torch.cuda.is_available())
65
+ net.to(gpu_ids[0])
66
+ net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
67
+ init_weights(net, init_type, init_gain=init_gain)
68
+ return net
69
+
70
+
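+ # Illustrative usage (a sketch; `MyGenerator` and the argument values are assumptions, not
+ # part of this repository):
+ #     net = init_net(MyGenerator(), init_type='N02', init_gain=0.02, gpu_ids=[0])
+ #     # or, to only re-initialize weights in place:
+ #     init_weights(net, init_type='xavier', init_gain=0.02)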
71
+ def get_scheduler(optimizer, opt):
72
+ """Return a learning rate scheduler
73
+
74
+ Parameters:
75
+ optimizer -- the optimizer of the network
76
+ opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
77
+ opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
78
+
79
+ For 'linear', we keep the same learning rate for the first <opt.niter> epochs
80
+ and linearly decay the rate to zero over the next <opt.niter_decay> epochs.
81
+ For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
82
+ See https://pytorch.org/docs/stable/optim.html for more details.
83
+ """
84
+ if opt.lr_policy == 'linear':
85
+ def lambda_rule(epoch):
86
+ lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
87
+ return lr_l
88
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
89
+ elif opt.lr_policy == 'step':
90
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
91
+ elif opt.lr_policy == 'plateau':
92
+ scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
93
+ elif opt.lr_policy == 'cosine':
94
+ scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
95
+ else:
96
+ raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
97
+ return scheduler
98
+
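+ # Illustrative usage (a sketch; the `opt` fields are assumptions inferred from the branches
+ # above):
+ #     opt.lr_policy, opt.niter, opt.niter_decay, opt.epoch_count = 'linear', 100, 100, 1
+ #     scheduler = get_scheduler(optimizer, opt)
+ #     for epoch in range(opt.niter + opt.niter_decay):
+ #         ...  # train for one epoch
+ #         scheduler.step()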
models/positional_encodings.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ def get_emb(sin_inp):
7
+ """
8
+ Gets a base embedding for one dimension with sin and cos intertwined
9
+ """
10
+ emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
11
+ return torch.flatten(emb, -2, -1)
12
+
13
+
14
+ class PositionalEncoding1D(nn.Module):
15
+ def __init__(self, channels):
16
+ """
17
+ :param channels: The last dimension of the tensor you want to apply pos emb to.
18
+ """
19
+ super(PositionalEncoding1D, self).__init__()
20
+ self.org_channels = channels
21
+ channels = int(np.ceil(channels / 2) * 2)
22
+ self.channels = channels
23
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
24
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
25
+ self.cached_penc = None
26
+
27
+ def forward(self, tensor):
28
+ """
29
+ :param tensor: A 3d tensor of size (batch_size, x, ch)
30
+ :return: Positional Encoding Matrix of size (batch_size, x, ch)
31
+ """
32
+ if len(tensor.shape) != 3:
33
+ raise RuntimeError("The input tensor has to be 3d!")
34
+
35
+ if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
36
+ return self.cached_penc
37
+
38
+ self.cached_penc = None
39
+ batch_size, x, orig_ch = tensor.shape
40
+ pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type())
41
+ sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
42
+ emb_x = get_emb(sin_inp_x)
43
+ emb = torch.zeros((x, self.channels), device=tensor.device).type(tensor.type())
44
+ emb[:, : self.channels] = emb_x
45
+
46
+ self.cached_penc = emb[None, :, :orig_ch].repeat(batch_size, 1, 1)
47
+ return self.cached_penc
48
+
49
+
50
+ class PositionalEncodingPermute1D(nn.Module):
51
+ def __init__(self, channels):
52
+ """
53
+ Accepts (batchsize, ch, x) instead of (batchsize, x, ch)
54
+ """
55
+ super(PositionalEncodingPermute1D, self).__init__()
56
+ self.penc = PositionalEncoding1D(channels)
57
+
58
+ def forward(self, tensor):
59
+ tensor = tensor.permute(0, 2, 1)
60
+ enc = self.penc(tensor)
61
+ return enc.permute(0, 2, 1)
62
+
63
+ @property
64
+ def org_channels(self):
65
+ return self.penc.org_channels
66
+
67
+
68
+ class PositionalEncoding2D(nn.Module):
69
+ def __init__(self, channels):
70
+ """
71
+ :param channels: The last dimension of the tensor you want to apply pos emb to.
72
+ """
73
+ super(PositionalEncoding2D, self).__init__()
74
+ self.org_channels = channels
75
+ channels = int(np.ceil(channels / 4) * 2)
76
+ self.channels = channels
77
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
78
+ self.register_buffer("inv_freq", inv_freq)
79
+ self.cached_penc = None
80
+
81
+ def forward(self, tensor):
82
+ """
83
+ :param tensor: A 4d tensor of size (batch_size, x, y, ch)
84
+ :return: Positional Encoding Matrix of size (batch_size, x, y, ch)
85
+ """
86
+ if len(tensor.shape) != 4:
87
+ raise RuntimeError("The input tensor has to be 4d!")
88
+
89
+ if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
90
+ return self.cached_penc
91
+
92
+ self.cached_penc = None
93
+ batch_size, x, y, orig_ch = tensor.shape
94
+ pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type())
95
+ pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type())
96
+ sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
97
+ sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
98
+ emb_x = get_emb(sin_inp_x).unsqueeze(1)
99
+ emb_y = get_emb(sin_inp_y)
100
+ emb = torch.zeros((x, y, self.channels * 2), device=tensor.device).type(
101
+ tensor.type()
102
+ )
103
+ emb[:, :, : self.channels] = emb_x
104
+ emb[:, :, self.channels : 2 * self.channels] = emb_y
105
+
106
+ self.cached_penc = emb[None, :, :, :orig_ch].repeat(tensor.shape[0], 1, 1, 1)
107
+ return self.cached_penc
108
+
109
+
110
+ class PositionalEncodingPermute2D(nn.Module):
111
+ def __init__(self, channels):
112
+ """
113
+ Accepts (batchsize, ch, x, y) instead of (batchsize, x, y, ch)
114
+ """
115
+ super(PositionalEncodingPermute2D, self).__init__()
116
+ self.penc = PositionalEncoding2D(channels)
117
+
118
+ def forward(self, tensor):
119
+ tensor = tensor.permute(0, 2, 3, 1)
120
+ enc = self.penc(tensor)
121
+ return enc.permute(0, 3, 1, 2)
122
+
123
+ @property
124
+ def org_channels(self):
125
+ return self.penc.org_channels
126
+
127
+
128
+ class PositionalEncoding3D(nn.Module):
129
+ def __init__(self, channels):
130
+ """
131
+ :param channels: The last dimension of the tensor you want to apply pos emb to.
132
+ """
133
+ super(PositionalEncoding3D, self).__init__()
134
+ self.org_channels = channels
135
+ channels = int(np.ceil(channels / 6) * 2)
136
+ if channels % 2:
137
+ channels += 1
138
+ self.channels = channels
139
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
140
+ self.register_buffer("inv_freq", inv_freq)
141
+ self.cached_penc = None
142
+
143
+ def forward(self, tensor):
144
+ """
145
+ :param tensor: A 5d tensor of size (batch_size, x, y, z, ch)
146
+ :return: Positional Encoding Matrix of size (batch_size, x, y, z, ch)
147
+ """
148
+ if len(tensor.shape) != 5:
149
+ raise RuntimeError("The input tensor has to be 5d!")
150
+
151
+ if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
152
+ return self.cached_penc
153
+
154
+ self.cached_penc = None
155
+ batch_size, x, y, z, orig_ch = tensor.shape
156
+ pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type())
157
+ pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type())
158
+ pos_z = torch.arange(z, device=tensor.device).type(self.inv_freq.type())
159
+ sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
160
+ sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
161
+ sin_inp_z = torch.einsum("i,j->ij", pos_z, self.inv_freq)
162
+ emb_x = get_emb(sin_inp_x).unsqueeze(1).unsqueeze(1)
163
+ emb_y = get_emb(sin_inp_y).unsqueeze(1)
164
+ emb_z = get_emb(sin_inp_z)
165
+ emb = torch.zeros((x, y, z, self.channels * 3), device=tensor.device).type(
166
+ tensor.type()
167
+ )
168
+ emb[:, :, :, : self.channels] = emb_x
169
+ emb[:, :, :, self.channels : 2 * self.channels] = emb_y
170
+ emb[:, :, :, 2 * self.channels :] = emb_z
171
+
172
+ self.cached_penc = emb[None, :, :, :, :orig_ch].repeat(batch_size, 1, 1, 1, 1)
173
+ return self.cached_penc
174
+
175
+
176
+ class PositionalEncodingPermute3D(nn.Module):
177
+ def __init__(self, channels):
178
+ """
179
+ Accepts (batchsize, ch, x, y, z) instead of (batchsize, x, y, z, ch)
180
+ """
181
+ super(PositionalEncodingPermute3D, self).__init__()
182
+ self.penc = PositionalEncoding3D(channels)
183
+
184
+ def forward(self, tensor):
185
+ tensor = tensor.permute(0, 2, 3, 4, 1)
186
+ enc = self.penc(tensor)
187
+ return enc.permute(0, 4, 1, 2, 3)
188
+
189
+ @property
190
+ def org_channels(self):
191
+ return self.penc.org_channels
192
+
193
+
194
+ class Summer(nn.Module):
195
+ def __init__(self, penc):
196
+ """
197
+ :param model: The type of positional encoding to run the summer on.
198
+ """
199
+ super(Summer, self).__init__()
200
+ self.penc = penc
201
+
202
+ def forward(self, tensor):
203
+ """
204
+ :param tensor: A 3, 4 or 5d tensor that matches the model output size
205
+ :return: Positional Encoding Matrix summed to the original tensor
206
+ """
207
+ penc = self.penc(tensor)
208
+ assert (
209
+ tensor.size() == penc.size()
210
+ ), "The original tensor size {} and the positional encoding tensor size {} must match!".format(
211
+ tensor.size(), penc.size()
212
+ )
213
+ return tensor + penc
214
+
215
+
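+ # Illustrative usage (a sketch): Summer wraps any of the encodings above and adds the
+ # encoding to its input instead of returning it separately.
+ #     add_pe = Summer(PositionalEncoding2D(64))
+ #     y = add_pe(torch.zeros(1, 32, 32, 64))  # same shape as the input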
216
+ class SparsePositionalEncoding2D(PositionalEncoding2D):
217
+ def __init__(self, channels, x, y, device='cuda'):
218
+ super(SparsePositionalEncoding2D, self).__init__(channels)
219
+ self.y, self.x = y, x
220
+ self.fake_tensor = torch.zeros((1, x, y, channels), device=device)
221
+
222
+ def forward(self, coords):
223
+ """
224
+ :param coords: A list of list of coordinates (((x1, y1), (x2, y22), ... ), ... )
225
+ :return: Positional Encoding Matrix summed to the original tensor
226
+ """
227
+ encodings = super().forward(self.fake_tensor)
228
+ encodings = encodings.permute(0, 3, 1, 2)
229
+ indices = torch.nn.utils.rnn.pad_sequence([torch.LongTensor(c) for c in coords], batch_first=True, padding_value=-1)
230
+ indices = indices.unsqueeze(0).to(self.fake_tensor.device)
231
+ assert self.x == self.y
232
+ indices = (indices + 0.5) / self.x * 2 - 1
233
+ indices = torch.flip(indices, (-1, ))
234
+ return torch.nn.functional.grid_sample(encodings, indices).squeeze().permute(2, 1, 0)
235
+
236
+ # all_encodings = []
237
+ # for coords_row in coords:
238
+ # res_encodings = []
239
+ # for xy in coords_row:
240
+ # if xy is None:
241
+ # res_encodings.append(padding)
242
+ # else:
243
+ # x, y = xy
244
+ # res_encodings.append(encodings[x, y, :])
245
+ # all_encodings.append(res_encodings)
246
+ # return torch.stack(res_encodings).to(self.fake_tensor.device)
247
+
248
+ # coords = torch.Tensor(coords).to(self.fake_tensor.device).long()
249
+ # assert torch.all(coords[:, 0] < self.x)
250
+ # assert torch.all(coords[:, 1] < self.y)
251
+ # coords = coords[:, 0] + (coords[:, 1] * self.x)
252
+ # encodings = super().forward(self.fake_tensor).reshape((-1, self.org_channels))
253
+ # return encodings[coords]
254
+
255
+ if __name__ == '__main__':
256
+ pos = SparsePositionalEncoding2D(10, 20, 20, device='cpu')  # square grid (forward() asserts x == y); CPU so the smoke test runs without CUDA
257
+ pos([[[0, 0], [0, 9]], [[1, 0], [9, 15]]])  # two coordinate lists, matching the nesting described in forward()'s docstring