vittoriopippi committed · Commit fa0f216 · Parent(s): 434bf7c
Initial commit

This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitignore +3 -0
- Groundtruth/gan.iam.test.gt.filter27 +0 -0
- Groundtruth/gan.iam.tr_va.gt.filter27 +0 -0
- README.md +99 -3
- config.json +46 -0
- configuration_vatrpp.py +82 -0
- corpora_english/brown-azAZ.tr +0 -0
- corpora_english/in_vocab.subset.tro.37 +114 -0
- corpora_english/oov.common_words +79 -0
- corpora_english/oov_words.txt +400 -0
- create_style_sample.py +25 -0
- data/create_data.py +469 -0
- data/dataset.py +324 -0
- data/iam_test.py +51 -0
- data/show_dataset.py +149 -0
- files/IAM-32-pa.pickle +3 -0
- files/IAM-32.pickle +3 -0
- files/cvl_model.pth +3 -0
- files/english_words.txt +0 -0
- files/files +1 -0
- files/hwt.pth +3 -0
- files/resnet_18_pretrained.pth +3 -0
- files/unifont.pickle +3 -0
- files/vatr.pth +3 -0
- files/vatrpp.pth +3 -0
- generate.py +49 -0
- generate/__init__.py +5 -0
- generate/authors.py +48 -0
- generate/fid.py +63 -0
- generate/ocr.py +72 -0
- generate/page.py +57 -0
- generate/text.py +24 -0
- generate/util.py +15 -0
- generate/writer.py +329 -0
- generation_config.json +4 -0
- hwt/config.json +46 -0
- hwt/generation_config.json +4 -0
- hwt/model.safetensors +3 -0
- model.safetensors +3 -0
- modeling_vatrpp.py +338 -0
- models/BigGAN_layers.py +469 -0
- models/BigGAN_networks.py +379 -0
- models/OCR_network.py +193 -0
- models/__init__.py +65 -0
- models/blocks.py +190 -0
- models/config.py +6 -0
- models/inception.py +311 -0
- models/model.py +894 -0
- models/networks.py +98 -0
- models/positional_encodings.py +257 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
taylor_swift.png
test.py
*.pyc
Groundtruth/gan.iam.test.gt.filter27
ADDED
The diff for this file is too large to render.
See raw diff
Groundtruth/gan.iam.tr_va.gt.filter27
ADDED
The diff for this file is too large to render.
See raw diff
README.md
CHANGED
@@ -1,3 +1,99 @@
# Handwritten Text Generation from Visual Archetypes ++

This repository includes the code for training the VATr++ Styled Handwritten Text Generation model.

## Installation

```bash
conda create --name vatr python=3.9
conda activate vatr
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia
git clone https://github.com/aimagelab/VATr.git && cd VATr
pip install -r requirements.txt
```

[This folder](https://drive.google.com/drive/folders/13rJhjl7VsyiXlPTBvnp1EKkKEhckLalr?usp=sharing) contains the regular IAM dataset `IAM-32.pickle` and the modified version with attached punctuation marks `IAM-32-pa.pickle`.
The folder also contains the synthetically pretrained weights for the encoder, `resnet_18_pretrained.pth`.
Please download these files and place them into the `files` folder.

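For reference, these pickles follow the layout produced by `data/create_data.py` and consumed by `data/dataset.py`: a dictionary with `train`/`test` splits mapping each writer id to a list of word samples. A minimal inspection sketch, assuming the default path used above:

```python
import pickle

# Inspect the downloaded dataset pickle (layout as written by data/create_data.py).
with open("files/IAM-32.pickle", "rb") as f:
    data = pickle.load(f)

train = data["train"]                        # writer id -> list of word samples
writer_id, samples = next(iter(train.items()))
sample = samples[0]                          # dict with a PIL image and its transcription
print(writer_id, sample["label"], sample["img"].size)
```
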
## Training

To train the regular VATr model, use the following command. This uses the default settings from the paper.

```bash
python train.py
```

Useful arguments:
```bash
python train.py
        --feat_model_path PATH  # path to the pretrained ResNet-18 checkpoint. By default this is the synthetically pretrained model
        --is_cycle              # use style cycle loss for training
        --dataset DATASET       # dataset to use. Default IAM
        --resume                # resume training from the last checkpoint with the same name
        --wandb                 # use wandb for logging
```

Use the following arguments to apply the full VATr++ training:
```bash
python train.py
        --d-crop-size 64 128          # Randomly crop input to discriminator to width 64 to 128
        --text-augment-strength 0.4   # Text augmentation for adding more rare characters
        --file-suffix pa              # Use the punctuation-attached version of IAM
        --augment-ocr                 # Augment the real images used to train the OCR model
```

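These VATr++ flags mirror identically named fields of `VATrPPConfig` (`configuration_vatrpp.py`) and `config.json` in this repository. A rough sketch of the corresponding config overrides; the tuple form of `d_crop_size` is an assumption (the shipped default is `null`):

```python
from configuration_vatrpp import VATrPPConfig

# Sketch only: map the CLI flags above onto the config fields with the same names.
cfg = VATrPPConfig(
    d_crop_size=(64, 128),        # --d-crop-size 64 128
    text_augment_strength=0.4,    # --text-augment-strength 0.4
    file_suffix="pa",             # --file-suffix pa
    augment_ocr=True,             # --augment-ocr
)
```
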
### Pretraining dataset
The model `resnet_18_pretrained.pth` was pretrained on this dataset: [Font Square](https://github.com/aimagelab/font_square)


## Generate Styled Handwritten Text Images

We added some utilities to generate handwritten text images using the trained model. They are used as follows:

```bash
python generate.py [ACTION] --checkpoint files/vatrpp.pth
```

The following actions are available with their respective arguments.

### Custom Author

Generate the given text for a custom author.

```bash
text  --text STRING         # String to generate
      --text-path PATH      # Optional path to text file
      --output PATH         # Optional output location, default: files/output.png
      --style-folder PATH   # Optional style folder containing writer samples, default: 'files/style_samples/00'
```
Style samples for the author are needed. These can be automatically generated from an image of a page using `create_style_sample.py`.
```bash
python create_style_sample.py  --input-image PATH     # Path of the image to extract the style samples from.
                               --output-folder PATH   # Folder where the style samples should be saved
```

### All Authors

Generate some text for all authors of IAM. The output is saved to `saved_images/author_samples/`.

```bash
authors  --test-set          # Generate authors of the test set, otherwise the training set is generated
         --checkpoint PATH   # Checkpoint used to generate text, files/vatr.pth by default
         --align             # Detect the bottom lines for each word and align them
         --at-once           # Generate the whole sentence at once instead of word-by-word
         --output-style      # Also save the style images used to generate the words
```

### Evaluation Images

```bash
fid  --target_dataset_path PATH   # dataset file for which the test set will be generated
     --dataset-path PATH          # dataset file from which style samples will be taken, for example the punctuation-attached version
     --output PATH                # where to save the images, default is saved_images/fid
     --checkpoint PATH            # Checkpoint used to generate text, files/vatr.pth by default
     --all-epochs                 # Generate evaluation images for all saved epochs available (checkpoint has to be a folder)
     --fake-only                  # Only output fake images, no ground truth
     --test-only                  # Only generate the test set, not the train set
     --long-tail                  # Only generate words containing long-tail characters
```
config.json
ADDED
@@ -0,0 +1,46 @@
{
  "add_noise": false,
  "alphabet": "Only thewigsofrcvdampbkuq.A-210xT5'MDL,RYHJ\"ISPWENj&BC93VGFKz();#:!7U64Q8?+*ZX/%",
  "architectures": [
    "VATrPP"
  ],
  "augment_ocr": false,
  "batch_size": 8,
  "corpus": "standard",
  "d_crop_size": null,
  "d_lr": 1e-05,
  "dataset": "IAM",
  "device": "cuda",
  "english_words_path": "files/english_words.txt",
  "epochs": 100000,
  "feat_model_path": "files/resnet_18_pretrained.pth",
  "file_suffix": null,
  "g_lr": 5e-05,
  "img_height": 32,
  "is_cycle": false,
  "label_encoder": "default",
  "model_type": "emuru",
  "no_ocr_loss": false,
  "no_writer_loss": false,
  "num_examples": 15,
  "num_words": 3,
  "num_workers": 0,
  "num_writers": 339,
  "ocr_lr": 5e-05,
  "query_input": "unifont",
  "resolution": 16,
  "save_model": 5,
  "save_model_history": 500,
  "save_model_path": "saved_models",
  "seed": 742,
  "special_alphabet": "\u0391\u03b1\u0392\u03b2\u0393\u03b3\u0394\u03b4\u0395\u03b5\u0396\u03b6\u0397\u03b7\u0398\u03b8\u0399\u03b9\u039a\u03ba\u039b\u03bb\u039c\u03bc\u039d\u03bd\u039e\u03be\u039f\u03bf\u03a0\u03c0\u03a1\u03c1\u03a3\u03c3\u03c2\u03a4\u03c4\u03a5\u03c5\u03a6\u03c6\u03a7\u03c7\u03a8\u03c8\u03a9\u03c9",
  "tag": "debug",
  "text_aug_type": "proportional",
  "text_augment_strength": 0.0,
  "torch_dtype": "float32",
  "transformers_version": "4.46.2",
  "vocab_size": 80,
  "w_lr": 5e-05,
  "wandb": false,
  "writer_loss_weight": 1.0
}
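A minimal sketch of reading the shipped `config.json`; its values mirror the `VATrPPConfig` defaults defined in `configuration_vatrpp.py` below:

```python
import json

# Read the shipped config; the 80-character alphabet should match vocab_size.
with open("config.json") as f:
    cfg = json.load(f)

print(len(cfg["alphabet"]), cfg["vocab_size"])                    # expected: 80 80
print(cfg["num_writers"], "writers,", cfg["img_height"], "px image height")
```
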
configuration_vatrpp.py
ADDED
@@ -0,0 +1,82 @@
from transformers import PretrainedConfig

class VATrPPConfig(PretrainedConfig):
    model_type = "emuru"

    def __init__(self,
                 feat_model_path='files/resnet_18_pretrained.pth',
                 label_encoder='default',
                 save_model_path='saved_models',
                 dataset='IAM',
                 english_words_path='files/english_words.txt',
                 wandb=False,
                 no_writer_loss=False,
                 writer_loss_weight=1.0,
                 no_ocr_loss=False,
                 img_height=32,
                 resolution=16,
                 batch_size=8,
                 num_examples=15,
                 num_writers=339,
                 alphabet='Only thewigsofrcvdampbkuq.A-210xT5\'MDL,RYHJ"ISPWENj&BC93VGFKz();#:!7U64Q8?+*ZX/%',
                 special_alphabet='ΑαΒβΓγΔδΕεΖζΗηΘθΙιΚκΛλΜμΝνΞξΟοΠπΡρΣσςΤτΥυΦφΧχΨψΩω',
                 g_lr=0.00005,
                 d_lr=0.00001,
                 w_lr=0.00005,
                 ocr_lr=0.00005,
                 epochs=100000,
                 num_workers=0,
                 seed=742,
                 num_words=3,
                 is_cycle=False,
                 add_noise=False,
                 save_model=5,
                 save_model_history=500,
                 tag='debug',
                 device='cuda',
                 query_input='unifont',
                 corpus="standard",
                 text_augment_strength=0.0,
                 text_aug_type="proportional",
                 file_suffix=None,
                 augment_ocr=False,
                 d_crop_size=None,
                 **kwargs):
        super().__init__(**kwargs)
        self.feat_model_path = feat_model_path
        self.label_encoder = label_encoder
        self.save_model_path = save_model_path
        self.dataset = dataset
        self.english_words_path = english_words_path
        self.wandb = wandb
        self.no_writer_loss = no_writer_loss
        self.writer_loss_weight = writer_loss_weight
        self.no_ocr_loss = no_ocr_loss
        self.img_height = img_height
        self.resolution = resolution
        self.batch_size = batch_size
        self.num_examples = num_examples
        self.num_writers = num_writers
        self.alphabet = alphabet
        self.special_alphabet = special_alphabet
        self.g_lr = g_lr
        self.d_lr = d_lr
        self.w_lr = w_lr
        self.ocr_lr = ocr_lr
        self.epochs = epochs
        self.num_workers = num_workers
        self.seed = seed
        self.num_words = num_words
        self.is_cycle = is_cycle
        self.add_noise = add_noise
        self.save_model = save_model
        self.save_model_history = save_model_history
        self.tag = tag
        self.device = device
        self.query_input = query_input
        self.corpus = corpus
        self.text_augment_strength = text_augment_strength
        self.text_aug_type = text_aug_type
        self.file_suffix = file_suffix
        self.augment_ocr = augment_ocr
        self.d_crop_size = d_crop_size
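A minimal usage sketch relying only on the standard `transformers.PretrainedConfig` API; the output directory name is an example:

```python
from configuration_vatrpp import VATrPPConfig

# Instantiate the config with a couple of overrides and round-trip it to disk.
cfg = VATrPPConfig(batch_size=16, wandb=True)
print(cfg.model_type)                         # "emuru"
cfg.save_pretrained("vatrpp-config")          # writes a config.json
cfg2 = VATrPPConfig.from_pretrained("vatrpp-config")
assert cfg2.batch_size == 16
```
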
corpora_english/brown-azAZ.tr
ADDED
The diff for this file is too large to render.
See raw diff
corpora_english/in_vocab.subset.tro.37
ADDED
@@ -0,0 +1,114 @@
accents
fifty
gross
Tea
whom
renamed
Heaven
Harry
arrange
captain
why
Father
beaten
Bar
base
creamy
About
Allies
sound
farmers
anyone
steel
Mary
used
fever
looking
lately
returns
humans
finals
beyond
lots
waiting
cited
measure
posse
blow
blonde
twice
Having
compels
rooms
cocked
virtual
dying
tons
Travel
idea
gripped
Act
reign
moods
altered
sample
Soviet
thick
enigma
here
egghead
Public
Bryan
porous
estate
guilty
Caught
Lucas
observe
mouth
pricked
obscure
casual
take
home
amber
weekend
forming
aid
outlook
uniting
But
earnest
bear
news
sparked
merrily
extreme
North
damned
big
bosses
context
easily
took
hurried
Gene
due
deserve
cult
leisure
critics
parish
Music
charge
grey
Privy
Fred
massive
others
shirt
average
warning
Tuesday
locked
possess
corpora_english/oov.common_words
ADDED
@@ -0,0 +1,79 @@
planets
lips
varies
impact
skips
Gold
maple
voyager
noisy
stick
forums
drafts
crimson
sever
rackets
sexy
humming
cheated
lick
grades
heroic
Clever
foul
mood
warrior
Morning
poetic
nodding
certify
reviews
mosaics
senders
Isle
Lied
sand
Weight
writer
trusts
slot
eaten
squares
lists
vary
witches
compose
demons
therapy
focus
sticks
Whose
bumped
visibly
redeem
arsenal
lunatic
Similar
Bug
adheres
trail
robbing
Whisky
super
screwed
Flower
salads
Glow
Vapor
Married
recieve
handle
push
card
skiing
lotus
cloud
windy
monkey
virus
thunder
corpora_english/oov_words.txt
ADDED
@@ -0,0 +1,400 @@
planets
lips
varies
impact
skips
Gold
maple
voyager
noisy
stick
forums
drafts
crimson
sever
rackets
sexy
humming
cheated
lick
grades
heroic
Clever
foul
mood
warrior
Morning
poetic
nodding
certify
reviews
mosaics
senders
Isle
Lied
sand
Weight
writer
trusts
slot
eaten
squares
lists
vary
witches
compose
demons
therapy
focus
sticks
Whose
bumped
visibly
redeem
arsenal
lunatic
Similar
Bug
adheres
trail
robbing
Whisky
super
screwed
Flower
salads
Glow
Vapor
Married
recieve
handle
push
card
skiing
lotus
cloud
windy
monkey
virus
thunder
Keegan
purling
Orpheus
Prence
Yin
Kansas
jowls
Alabama
Szold
Chou
Orange
suspend
barred
deceit
reward
soy
Vail
lad
Loesser
Hutton
jerks
yelling
Heywood
sacker
comest
tense
par
fiend
Soiree
voted
Putting
pansy
doormen
mayor
Owens
noting
pauses
USP
crudely
grooved
furor
ignited
kittens
broader
slang
ballets
quacked
Paulus
Castles
upswing
dabbled
Animals
Kidder
Writers
laces
bled
scoped
yield
scoured
Schenk
Wratten
Menfolk
foamy
scratch
minced
nudged
Seats
Judging
Turbine
Strict
whined
crupper
Dussa
finned
voter
Jacobs
calmly
hip
clubs
quintet
blunts
Grazie
Barton
NAB
specie
Fonta
narrow
Swan
denials
Rawson
potato
Choral
diverse
Educate
unities
Ferry
Bonner
manuals
NAIR
imputed
initial
wallet
Sesame
maroon
Related
Quiney
Monster
brainy
Nolan
Thrifty
Tel
Ye
Sumter
Bonnet
sheepe
nagged
ribbing
hunt
AA
Pohly
triol
saws
popped
aloof
Ceramic
thong
typed
broadly
Figures
riddle
Otis
Sainted
upbeat
Getting
hisself
junta
Labans
starter
coward
Anthea
hurlers
Dervish
Turin
oud
tyranny
Rotary
Veneto
pulls
bowl
utopias
auburn
osmotic
myrtle
furrow
laws
Uh
Hodges
Wilde
Neck
snaked
decorum
edema
Dunston
clinics
Abide
Dover
voltaic
Modern
Farr
thaw
moi
leaning
wedlock
Carson
star
Hymn
Stack
genes
Shayne
Moune
slipped
legatee
coerced
Gates
pulse
Granny
bat
Fruit
Cadesi
Tee
Dreiser
Getz
Ways
cogs
hydrous
sweep
quarrel
mobcaps
slash
throats
Royaux
cafes
crusher
rusted
Eskimo
slatted
pallet
yelps
slanted
confide
Gomez
untidy
Sigmund
Marine
roll
NRL
Dukes
tumours
LP
turtles
audible
Woodrow
retreat
Orders
Conlow
hobby
skin
tally
frosted
drowned
wedged
queen
poised
eluded
Letter
ticking
kill
rancor
Plant
Brandel
Willows
riddles
carven
Spiller
yen
jerky
tenure
daubed
Serves
pimpled
ACTH
ruh
afield
suffuse
muffins
Miners
Cabrini
weakly
upriver
Newsom
Meeker
weed
fiscal
Diane
Errors
Mig
biz
Drink
chop
Bumbry
Babin
optimum
Leyden
enrage
induces
newel
trim
bolts
frog
cinder
Lo
clobber
Mennen
Othon
Ocean
jerking
engine
Belasco
hero
flora
Injuns
Rico
Gary
snake
hating
Suggs
booze
Lescaut
Molard
startle
Aggie
lengthy
Shoals
ideals
Zen
stem
noon
hoes
Seafood
yuh
Mostly
seeds
bestow
acetate
jokers
waning
volumes
ein
Rich
Galt
pasted
create_style_sample.py
ADDED
@@ -0,0 +1,25 @@
import os
import argparse

import cv2
from util.vision import get_page, get_words


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--input-image", type=str, required=True)
    parser.add_argument("--output-folder", type=str, required=True, default='files/style_samples/00')

    args = parser.parse_args()

    image = cv2.imread(args.input_image)
    image = cv2.resize(image, (image.shape[1], image.shape[0]))
    result = get_page(image)
    words, _ = get_words(result)

    output_path = args.output_folder
    if not os.path.exists(output_path):
        os.mkdir(output_path)
    for i, word in enumerate(words):
        cv2.imwrite(os.path.join(output_path, f"word{i}.png"), word)
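A small sketch, with assumed paths, of what the script leaves behind: a folder of word crops `word0.png`, `word1.png`, ... which is the kind of style folder read by `generate.py text --style-folder ...` and by `FolderDataset` in `data/dataset.py`:

```python
from pathlib import Path

# Count the word crops extracted into the default style folder.
samples = sorted(Path("files/style_samples/00").glob("word*.png"))
print(f"{len(samples)} style word images extracted")
```
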
data/create_data.py
ADDED
@@ -0,0 +1,469 @@
import gzip
import json
import os
import pickle
import random
from collections import defaultdict

import PIL
import cv2
import numpy as np
from PIL import Image


TO_MERGE = {
    '.': 'left',
    ',': 'left',
    '!': 'left',
    '?': 'left',
    '(': 'right',
    ')': 'left',
    '\"': 'random',
    "\'": 'random',
    ":": 'left',
    ";": 'left',
    "-": 'random'
}

FILTER_ERR = False


def resize(image, size):
    image_pil = Image.fromarray(image.astype('uint8'), 'L')
    image_pil = image_pil.resize(size)
    return np.array(image_pil)


def get_author_ids(base_folder: str):
    with open(os.path.join(base_folder, "gan.iam.tr_va.gt.filter27"), 'r') as f:
        training_authors = [line.split(",")[0] for line in f]
        training_authors = set(training_authors)

    with open(os.path.join(base_folder, "gan.iam.test.gt.filter27"), 'r') as f:
        test_authors = [line.split(",")[0] for line in f]
        test_authors = set(test_authors)

    assert len(training_authors.intersection(test_authors)) == 0

    return training_authors, test_authors


class IAMImage:
    def __init__(self, image: np.array, label: str, image_id: int, line_id: str, bbox: list = None, iam_image_id: str = None):
        self.image = image
        self.label = label
        self.image_id = image_id
        self.line_id = line_id
        self.iam_image_id = iam_image_id
        self.has_bbox = False
        if bbox is not None:
            self.has_bbox = True
            self.x, self.y, self.w, self.h = bbox

    def merge(self, other: 'IAMImage'):
        global MERGER_COUNT
        assert self.has_bbox, "IAM image has no bounding box information"
        y = min(self.y, other.y)
        h = max(other.y + other.h, self.y + self.h) - y

        x = min(self.x, other.x)
        w = max(self.x + self.w, other.x + other.w) - x

        new_image = np.ones((h, w), dtype=self.image.dtype) * 255

        anchor_x = self.x - x
        anchor_y = self.y - y
        new_image[anchor_y:anchor_y + self.h, anchor_x:anchor_x + self.w] = self.image

        anchor_x = other.x - x
        anchor_y = other.y - y
        new_image[anchor_y:anchor_y + other.h, anchor_x:anchor_x + other.w] = other.image

        if other.x - (self.x + self.w) > 50:
            new_label = self.label + " " + other.label
        else:
            new_label = self.label + other.label
        new_id = self.image_id
        new_bbox = [x, y, w, h]

        new_iam_image_id = self.iam_image_id if len(self.label) > len(other.label) else other.iam_image_id
        return IAMImage(new_image, new_label, new_id, self.line_id, new_bbox, iam_image_id=new_iam_image_id)


def read_iam_lines(base_folder: str) -> dict:
    form_to_author = {}
    with open(os.path.join(base_folder, "forms.txt"), 'r') as f:
        for line in f:
            if not line.startswith("#"):
                form, author, *_ = line.split(" ")
                form_to_author[form] = author

    training_authors, test_authors = get_author_ids(base_folder)

    dataset_dict = {
        'train': defaultdict(list),
        'test': defaultdict(list),
        'other': defaultdict(list)
    }

    image_count = 0

    with open(os.path.join(base_folder, "sentences.txt"), 'r') as f:
        for line in f:
            if not line.startswith("#"):
                line_id, _, ok, *_, label = line.rstrip().split(" ")
                form_id = "-".join(line_id.split("-")[:2])
                author_id = form_to_author[form_id]

                if ok != 'ok' and FILTER_ERR:
                    continue

                line_label = ""
                for word in label.split("|"):
                    if not(len(line_label) == 0 or word in [".", ","]):
                        line_label += " "
                    line_label += word

                image_path = os.path.join(base_folder, "sentences", form_id.split("-")[0], form_id, f"{line_id}.png")

                subset = 'other'
                if author_id in training_authors:
                    subset = 'train'
                elif author_id in test_authors:
                    subset = 'test'

                im = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
                if im is not None and im.size > 1:
                    dataset_dict[subset][author_id].append(IAMImage(
                        im, line_label, image_count, line_id, None
                    ))
                    image_count += 1

    return dataset_dict


def read_iam(base_folder: str) -> dict:
    with open(os.path.join(base_folder, "forms.txt"), 'r') as f:
        forms = [line.rstrip() for line in f if not line.startswith("#")]

    training_authors, test_authors = get_author_ids(base_folder)

    image_info = {}
    with open(os.path.join(base_folder, "words.txt"), 'r') as f:
        for line in f:
            if not line.startswith("#"):
                image_id, ok, threshold, x, y, w, h, tag, *content = line.rstrip().split(" ")
                image_info[image_id] = {
                    'ok': ok == 'ok',
                    'threshold': threshold,
                    'content': " ".join(content) if isinstance(content, list) else content,
                    'bbox': [int(x), int(y), int(w), int(h)]
                }

    dataset_dict = {
        'train': defaultdict(list),
        'test': defaultdict(list),
        'other': defaultdict(list)
    }

    image_count = 0
    err_count = 0

    for form in forms:
        form_id, writer_id, *_ = form.split(" ")
        base_form = form_id.split("-")[0]

        form_path = os.path.join(base_folder, "words", base_form, form_id)

        for image_name in os.listdir(form_path):
            image_id = image_name.split(".")[0]
            info = image_info[image_id]

            subset = 'other'
            if writer_id in training_authors:
                subset = 'train'
            elif writer_id in test_authors:
                subset = 'test'

            if info['ok'] or not FILTER_ERR:
                im = cv2.imread(os.path.join(form_path, image_name), cv2.IMREAD_GRAYSCALE)
                if not info['ok'] and False:
                    cv2.destroyAllWindows()
                    print(info['content'])
                    cv2.imshow("image", im)
                    cv2.waitKey(0)

                if im is not None and im.size > 1:
                    dataset_dict[subset][writer_id].append(IAMImage(
                        im, info['content'], image_count, "-".join(image_id.split("-")[:3]), info['bbox'], iam_image_id=image_id
                    ))
                    image_count += 1
                else:
                    err_count += 1
                    print(f"Could not read image {image_name}, skipping")
            else:
                err_count += 1

    assert not dataset_dict['train'].keys() & dataset_dict['test'].keys(), "Training and Testing set have common authors"

    print(f"Skipped images: {err_count}")

    return dataset_dict


def read_cvl_set(set_folder: str):
    set_images = defaultdict(list)
    words_path = os.path.join(set_folder, "words")

    image_id = 0

    for author_id in os.listdir(words_path):
        author_path = os.path.join(words_path, author_id)

        for image_file in os.listdir(author_path):
            label = image_file.split("-")[-1].split(".")[0]
            line_id = "-".join(image_file.split("-")[:-2])

            stream = open(os.path.join(author_path, image_file), "rb")
            bytes = bytearray(stream.read())
            numpyarray = np.asarray(bytes, dtype=np.uint8)
            image = cv2.imdecode(numpyarray, cv2.IMREAD_UNCHANGED)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            if image is not None and image.size > 1:
                set_images[int(author_id)].append(IAMImage(image, label, image_id, line_id))
                image_id += 1

    return set_images


def read_cvl(base_folder: str):
    dataset_dict = {
        'test': read_cvl_set(os.path.join(base_folder, 'testset')),
        'train': read_cvl_set(os.path.join(base_folder, 'trainset'))
    }

    assert not dataset_dict['train'].keys() & dataset_dict[
        'test'].keys(), "Training and Testing set have common authors"

    return dataset_dict


def pad_top(image: np.array, height: int) -> np.array:
    result = np.ones((height, image.shape[1]), dtype=np.uint8) * 255
    result[height - image.shape[0]:, :image.shape[1]] = image

    return result


def scale_per_writer(writer_dict: dict, target_height: int, char_width: int = None) -> dict:
    for author_id in writer_dict.keys():
        max_height = max([image_dict.image.shape[0] for image_dict in writer_dict[author_id]])
        scale_y = target_height / max_height

        for image_dict in writer_dict[author_id]:
            image = image_dict.image
            scale_x = scale_y if char_width is None else len(image_dict.label) * char_width / image_dict.image.shape[1]
            #image = cv2.resize(image, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_CUBIC)
            image = resize(image, (int(image.shape[1] * scale_x), int(image.shape[0] * scale_y)))
            image_dict.image = pad_top(image, target_height)

    return writer_dict


def scale_images(writer_dict: dict, target_height: int, char_width: int = None) -> dict:
    for author_id in writer_dict.keys():
        for image_dict in writer_dict[author_id]:
            scale_y = target_height / image_dict.image.shape[0]
            scale_x = scale_y if char_width is None else len(image_dict.label) * char_width / image_dict.image.shape[1]
            #image_dict.image = cv2.resize(image_dict.image, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_CUBIC)
            image_dict.image = resize(image_dict.image, (int(image_dict.image.shape[1] * scale_x), target_height))
    return writer_dict


def scale_word_width(writer_dict: dict):
    for author_id in writer_dict.keys():
        for image_dict in writer_dict[author_id]:
            width = len(image_dict.label) * (image_dict.image.shape[0] / 2.0)
            image_dict.image = resize(image_dict.image, (int(width), image_dict.image.shape[0]))
    return writer_dict


def get_sentences(author_dict: dict):
    collected = defaultdict(list)
    for image in author_dict:
        collected[image.line_id].append(image)

    return [v for k, v in collected.items()]


def merge_author_words(author_words):
    def try_left_merge(index: int):
        if index > 0 and author_words[index - 1].line_id == author_words[index].line_id and not to_remove[index - 1] and not author_words[index - 1].label in TO_MERGE.keys():
            merged = author_words[index - 1].merge(author_words[index])
            author_words[index - 1] = merged
            to_remove[index] = True
            return True
        return False

    def try_right_merge(index: int):
        if index < len(author_words) - 1 and author_words[index].line_id == author_words[index + 1].line_id and not to_remove[index + 1] and not author_words[index + 1].label in TO_MERGE.keys():
            merged = iam_image.merge(author_words[index + 1])
            author_words[index + 1] = merged
            to_remove[index] = True
            return True
        return False

    to_remove = [False for _ in range(len(author_words))]
    for i in range(len(author_words)):
        iam_image = author_words[i]
        if iam_image.label in TO_MERGE.keys():
            merge_type = TO_MERGE[iam_image.label] if TO_MERGE[iam_image.label] != 'random' else random.choice(['left', 'right'])
            if merge_type == 'left':
                if not try_left_merge(i):
                    if not try_right_merge(i):
                        print(f"Could not merge char: {iam_image.label}")
            else:
                if not try_right_merge(i):
                    if not try_left_merge(i):
                        print(f"Could not merge char: {iam_image.label}")

    return [image for image, remove in zip(author_words, to_remove) if not remove], sum(to_remove)


def merge_punctuation(writer_dict: dict) -> dict:
    for author_id in writer_dict.keys():
        author_dict = writer_dict[author_id]

        merged = 1
        while merged > 0:
            author_dict, merged = merge_author_words(author_dict)

        writer_dict[author_id] = author_dict

    return writer_dict


def filter_punctuation(writer_dict: dict) -> dict:
    for author_id in writer_dict.keys():
        author_list = [im for im in writer_dict[author_id] if im.label not in TO_MERGE.keys()]

        writer_dict[author_id] = author_list

    return writer_dict


def filter_by_width(writer_dict: dict, target_height: int = 32, min_width: int = 16, max_width: int = 17) -> dict:
    def is_valid(iam_image: IAMImage) -> bool:
        target_width = (target_height / iam_image.image.shape[0]) * iam_image.image.shape[1]
        if len(iam_image.label) * min_width / 3 <= target_width <= len(iam_image.label) * max_width * 3:
            return True
        else:
            return False

    for author_id in writer_dict.keys():
        author_list = [im for im in writer_dict[author_id] if is_valid(im)]

        writer_dict[author_id] = author_list

    return writer_dict


def write_data(dataset_dict: dict, location: str, height, punct_mode: str = 'none', author_scale: bool = False, uniform_char_width: bool = False):
    assert punct_mode in ['none', 'filter', 'merge']
    result = {}
    for key in dataset_dict.keys():
        result[key] = {}

        subset_dict = dataset_dict[key]

        subset_dict = filter_by_width(subset_dict)

        if punct_mode == 'merge':
            subset_dict = merge_punctuation(subset_dict)
        elif punct_mode == 'filter':
            subset_dict = filter_punctuation(subset_dict)

        char_width = 16 if uniform_char_width else None

        if author_scale:
            subset_dict = scale_per_writer(subset_dict, height, char_width)
        else:
            subset_dict = scale_images(subset_dict, height, char_width)

        for author_id in subset_dict:
            author_images = []
            for image_dict in subset_dict[author_id]:
                author_images.append({
                    'img': PIL.Image.fromarray(image_dict.image),
                    'label': image_dict.label,
                    'image_id': image_dict.image_id,
                    'original_image_id': image_dict.iam_image_id
                })
            result[key][author_id] = author_images

    with open(location, 'wb') as f:
        pickle.dump(result, f)


def write_fid(dataset_dict: dict, location: str):
    data = dataset_dict['test']
    data = scale_images(data, 64, None)
    for author in data.keys():
        author_folder = os.path.join(location, author)
        os.mkdir(author_folder)
        count = 0
        for image in data[author]:
            img = image.image
            cv2.imwrite(os.path.join(author_folder, f"{count}.png"), img.squeeze().astype(np.uint8))
            count += 1


def write_images_per_author(dataset_dict: dict, output_file: str):
    data = dataset_dict["test"]

    result = {}

    for author in data.keys():
        author_images = [image.iam_image_id for image in data[author]]
        result[author] = author_images

    with open(output_file, 'w') as f:
        json.dump(result, f)


def write_words(dataset_dict: dict, output_file):
    data = dataset_dict['train']

    all_words = []

    for author in data.keys():
        all_words.extend([image.label for image in data[author]])

    with open(output_file, 'w') as f:
        for word in all_words:
            f.write(f"{word}\n")


if __name__ == "__main__":
    data_path = r"D:\Datasets\IAM"
    fid_location = r"E:/projects/evaluation/shtg_interface/data/reference_imgs/h64/iam"
    height = 32
    data_collection = {}

    output_location = r"E:\projects\evaluation\shtg_interface\data\datasets"

    data = read_iam(data_path)
    test_data = dict(scale_word_width(data['test']))
    train_data = dict(scale_word_width(data['train']))
    test_data.update(train_data)
    for key, value in test_data.items():
        for image_object in value:
            if len(image_object.label) <= 0 or image_object.image.size == 0:
                continue
            data_collection[image_object.iam_image_id] = {
                'img': image_object.image,
                'lbl': image_object.label,
                'author_id': key
            }

    with gzip.open(os.path.join(output_location, f"iam_w16_words_data.pkl.gz"), 'wb') as f:
        pickle.dump(data_collection, f)
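A minimal sketch of rebuilding a dataset pickle with the functions above, assuming the IAM root contains `forms.txt`, `words.txt` and the `words/` image folders, that the `Groundtruth` filter files are copied next to them, and that the script is run from the repository root. The `'merge'` mode is presumably what produces the punctuation-attached `-pa` variant:

```python
from data.create_data import read_iam, write_data

# Read the word-level IAM data and write a 32 px pickle with punctuation merged
# into neighbouring words (paths are hypothetical).
iam = read_iam("/path/to/IAM")
write_data(iam, "files/IAM-32-pa.pickle", height=32, punct_mode="merge")
```
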
data/dataset.py
ADDED
@@ -0,0 +1,324 @@
import random
from collections import defaultdict

import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import os
import pickle
import numpy as np
from PIL import Image
from pathlib import Path


def get_dataset_path(dataset_name, height, file_suffix, datasets_path):
    if file_suffix is not None:
        filename = f'{dataset_name}-{height}-{file_suffix}.pickle'
    else:
        filename = f'{dataset_name}-{height}.pickle'

    return os.path.join(datasets_path, filename)


def get_transform(grayscale=False, convert=True):
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))

    if convert:
        transform_list += [transforms.ToTensor()]
        if grayscale:
            transform_list += [transforms.Normalize((0.5,), (0.5,))]
        else:
            transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]

    return transforms.Compose(transform_list)


class TextDataset:

    def __init__(self, base_path, collator_resolution, num_examples=15, target_transform=None, min_virtual_size=0, validation=False, debug=False):
        self.NUM_EXAMPLES = num_examples
        self.debug = debug
        self.min_virtual_size = min_virtual_size

        subset = 'test' if validation else 'train'

        # base_path=DATASET_PATHS
        file_to_store = open(base_path, "rb")
        self.IMG_DATA = pickle.load(file_to_store)[subset]
        self.IMG_DATA = dict(list(self.IMG_DATA.items()))  # [:NUM_WRITERS])
        if 'None' in self.IMG_DATA.keys():
            del self.IMG_DATA['None']

        self.alphabet = ''.join(sorted(set(''.join(d['label'] for d in sum(self.IMG_DATA.values(), [])))))
        self.author_id = list(self.IMG_DATA.keys())

        self.transform = get_transform(grayscale=True)
        self.target_transform = target_transform

        self.collate_fn = TextCollator(collator_resolution)

    def __len__(self):
        if self.debug:
            return 16
        return max(len(self.author_id), self.min_virtual_size)

    @property
    def num_writers(self):
        return len(self.author_id)

    def __getitem__(self, index):
        index = index % len(self.author_id)

        author_id = self.author_id[index]

        self.IMG_DATA_AUTHOR = self.IMG_DATA[author_id]
        random_idxs = random.choices([i for i in range(len(self.IMG_DATA_AUTHOR))], k=self.NUM_EXAMPLES)

        word_data = random.choice(self.IMG_DATA_AUTHOR)
        real_img = self.transform(word_data['img'].convert('L'))
        real_labels = word_data['label'].encode()

        imgs = [np.array(self.IMG_DATA_AUTHOR[idx]['img'].convert('L')) for idx in random_idxs]
        slabels = [self.IMG_DATA_AUTHOR[idx]['label'].encode() for idx in random_idxs]

        max_width = 192  # [img.shape[1] for img in imgs]

        imgs_pad = []
        imgs_wids = []

        for img in imgs:
            img_height, img_width = img.shape[0], img.shape[1]
            output_img = np.ones((img_height, max_width), dtype='float32') * 255.0
            output_img[:, :img_width] = img[:, :max_width]

            imgs_pad.append(self.transform(Image.fromarray(output_img.astype(np.uint8))))
            imgs_wids.append(img_width)

        imgs_pad = torch.cat(imgs_pad, 0)

        item = {
            'simg': imgs_pad,       # N images (15) that come from the same author [N (15), H (32), MAX_W (192)]
            'swids': imgs_wids,     # widths of the N images [list(N)]
            'img': real_img,        # the input image [1, H (32), W]
            'label': real_labels,   # the label of the input image [byte]
            'img_path': 'img_path',
            'idx': 'indexes',
            'wcl': index,           # id of the author [int]
            'slabels': slabels,
            'author_id': author_id
        }
        return item

    def get_stats(self):
        char_counts = defaultdict(lambda: 0)
        total = 0

        for author in self.IMG_DATA.keys():
            for data in self.IMG_DATA[author]:
                for char in data['label']:
                    char_counts[char] += 1
                    total += 1

        char_counts = {k: 1.0 / (v / total) for k, v in char_counts.items()}

        return char_counts


class TextCollator(object):
    def __init__(self, resolution):
        self.resolution = resolution

    def __call__(self, batch):
        if isinstance(batch[0], list):
            batch = sum(batch, [])
        img_path = [item['img_path'] for item in batch]
        width = [item['img'].shape[2] for item in batch]
        indexes = [item['idx'] for item in batch]
        simgs = torch.stack([item['simg'] for item in batch], 0)
        wcls = torch.Tensor([item['wcl'] for item in batch])
        swids = torch.Tensor([item['swids'] for item in batch])
        imgs = torch.ones([len(batch), batch[0]['img'].shape[0], batch[0]['img'].shape[1], max(width)],
                          dtype=torch.float32)
        for idx, item in enumerate(batch):
            try:
                imgs[idx, :, :, 0:item['img'].shape[2]] = item['img']
            except:
                print(imgs.shape)
        item = {'img': imgs, 'img_path': img_path, 'idx': indexes, 'simg': simgs, 'swids': swids, 'wcl': wcls}
        if 'label' in batch[0].keys():
            labels = [item['label'] for item in batch]
            item['label'] = labels
        if 'slabels' in batch[0].keys():
            slabels = [item['slabels'] for item in batch]
            item['slabels'] = np.array(slabels)
        if 'z' in batch[0].keys():
            z = torch.stack([item['z'] for item in batch])
            item['z'] = z
        return item


class CollectionTextDataset(Dataset):
    def __init__(self, datasets, datasets_path, dataset_class, file_suffix=None, height=32, **kwargs):
        self.datasets = {}
        for dataset_name in sorted(datasets.split(',')):
            dataset_file = get_dataset_path(dataset_name, height, file_suffix, datasets_path)
            dataset = dataset_class(dataset_file, **kwargs)
            self.datasets[dataset_name] = dataset
        self.alphabet = ''.join(sorted(set(''.join(d.alphabet for d in self.datasets.values()))))

    def __len__(self):
        return sum(len(d) for d in self.datasets.values())

    @property
    def num_writers(self):
        return sum(d.num_writers for d in self.datasets.values())

    def __getitem__(self, index):
        for dataset in self.datasets.values():
            if index < len(dataset):
                return dataset[index]
            index -= len(dataset)
        raise IndexError

    def get_dataset(self, index):
        for dataset_name, dataset in self.datasets.items():
            if index < len(dataset):
                return dataset_name
            index -= len(dataset)
        raise IndexError

    def collate_fn(self, batch):
        return self.datasets[self.get_dataset(0)].collate_fn(batch)


class FidDataset(Dataset):
    def __init__(self, base_path, collator_resolution, num_examples=15, target_transform=None, mode='train', style_dataset=None):
        self.NUM_EXAMPLES = num_examples

        # base_path=DATASET_PATHS
        with open(base_path, "rb") as f:
            self.IMG_DATA = pickle.load(f)

        self.IMG_DATA = self.IMG_DATA[mode]
        if 'None' in self.IMG_DATA.keys():
            del self.IMG_DATA['None']

        self.STYLE_IMG_DATA = None
        if style_dataset is not None:
            with open(style_dataset, "rb") as f:
                self.STYLE_IMG_DATA = pickle.load(f)

            self.STYLE_IMG_DATA = self.STYLE_IMG_DATA[mode]
            if 'None' in self.STYLE_IMG_DATA.keys():
                del self.STYLE_IMG_DATA['None']

        self.alphabet = ''.join(sorted(set(''.join(d['label'] for d in sum(self.IMG_DATA.values(), [])))))
        self.author_id = sorted(self.IMG_DATA.keys())

        self.transform = get_transform(grayscale=True)
        self.target_transform = target_transform
        self.dataset_size = sum(len(samples) for samples in self.IMG_DATA.values())
        self.collate_fn = TextCollator(collator_resolution)

    def __len__(self):
        return self.dataset_size

    @property
    def num_writers(self):
        return len(self.author_id)

    def __getitem__(self, index):
        NUM_SAMPLES = self.NUM_EXAMPLES
        sample, author_id = None, None
        for author_id, samples in self.IMG_DATA.items():
            if index < len(samples):
                sample, author_id = samples[index], author_id
                break
            index -= len(samples)

        real_image = self.transform(sample['img'].convert('L'))
        real_label = sample['label'].encode()

        style_dataset = self.STYLE_IMG_DATA if self.STYLE_IMG_DATA is not None else self.IMG_DATA

        author_style_images = style_dataset[author_id]
        random_idxs = np.random.choice(len(author_style_images), NUM_SAMPLES, replace=True)
        style_images = [np.array(author_style_images[idx]['img'].convert('L')) for idx in random_idxs]

        max_width = 192

        imgs_pad = []
        imgs_wids = []

        for img in style_images:
            img = 255 - img
            img_height, img_width = img.shape[0], img.shape[1]
            outImg = np.zeros((img_height, max_width), dtype='float32')
            outImg[:, :img_width] = img[:, :max_width]

            img = 255 - outImg

            imgs_pad.append(self.transform(Image.fromarray(img.astype(np.uint8))))
            imgs_wids.append(img_width)

        imgs_pad = torch.cat(imgs_pad, 0)

        item = {
            'simg': imgs_pad,     # N images (15) that come from the same author [N (15), H (32), MAX_W (192)]
            'swids': imgs_wids,   # widths of the N images [list(N)]
            'img': real_image,    # the input image [1, H (32), W]
            'label': real_label,  # the label of the input image [byte]
            'img_path': 'img_path',
            'idx': sample['img_id'] if 'img_id' in sample.keys() else sample['image_id'],
            'wcl': int(author_id)  # id of the author [int]
        }
        return item


class FolderDataset:
    def __init__(self, folder_path, num_examples=15, word_lengths=None):
        folder_path = Path(folder_path)
        self.imgs = list([p for p in folder_path.iterdir() if not p.suffix == '.txt'])
        self.transform = get_transform(grayscale=True)
        self.num_examples = num_examples
        self.word_lengths = word_lengths

    def __len__(self):
        return len(self.imgs)

    def sample_style(self):
        random_idxs = np.random.choice(len(self.imgs), self.num_examples, replace=False)
        image_names = [self.imgs[idx].stem for idx in random_idxs]
        imgs = [Image.open(self.imgs[idx]).convert('L') for idx in random_idxs]
        if self.word_lengths is None:
            imgs = [img.resize((img.size[0] * 32 // img.size[1], 32), Image.BILINEAR) for img in imgs]
        else:
            imgs = [img.resize((self.word_lengths[name] * 16, 32), Image.BILINEAR) for img, name in zip(imgs, image_names)]
        imgs = [np.array(img) for img in imgs]

        max_width = 192  # [img.shape[1] for img in imgs]

        imgs_pad = []
        imgs_wids = []

        for img in imgs:
            img = 255 - img
            img_height, img_width = img.shape[0], img.shape[1]
            outImg = np.zeros((img_height, max_width), dtype='float32')
            outImg[:, :img_width] = img[:, :max_width]

            img = 255 - outImg

            imgs_pad.append(self.transform(Image.fromarray(img.astype(np.uint8))))
            imgs_wids.append(img_width)

        imgs_pad = torch.cat(imgs_pad, 0)

        item = {
            'simg': imgs_pad,    # N style images [N (15), H (32), MAX_W (192)]
            'swids': imgs_wids,  # widths of the N images [list(N)]
        }
        return item
data/iam_test.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
|
4 |
+
def test_split():
|
5 |
+
iam_path = r"C:\Users\bramv\Documents\Werk\Research\Unimore\datasets\IAM"
|
6 |
+
|
7 |
+
original_set_names = ["trainset.txt", "validationset1.txt", "validationset2.txt", "testset.txt"]
|
8 |
+
original_set_ids = []
|
9 |
+
|
10 |
+
print("ORIGINAL IAM")
|
11 |
+
print("---------------------")
|
12 |
+
|
13 |
+
for set_name in original_set_names:
|
14 |
+
with open(os.path.join(iam_path, set_name), 'r') as f:
|
15 |
+
set_form_ids = ["-".join(l.rstrip().split("-")[:-1]) for l in f]
|
16 |
+
|
17 |
+
form_to_id = {}
|
18 |
+
with open(os.path.join(iam_path, "forms.txt"), 'r') as f:
|
19 |
+
for line in f:
|
20 |
+
if line.startswith("#"):
|
21 |
+
continue
|
22 |
+
form, id, *_ = line.split(" ")
|
23 |
+
assert form not in form_to_id.keys() or form_to_id[form] == id
|
24 |
+
form_to_id[form] = int(id)
|
25 |
+
|
26 |
+
set_authors = [form_to_id[form] for form in set_form_ids]
|
27 |
+
|
28 |
+
set_authors = set(sorted(set_authors))
|
29 |
+
original_set_ids.append(set_authors)
|
30 |
+
print(f"{set_name} count: {len(set_authors)}")
|
31 |
+
|
32 |
+
htg_set_names = ["gan.iam.tr_va.gt.filter27", "gan.iam.test.gt.filter27"]
|
33 |
+
|
34 |
+
print("\n\nHTG IAM")
|
35 |
+
print("---------------------")
|
36 |
+
|
37 |
+
for set_name in htg_set_names:
|
38 |
+
with open(os.path.join(iam_path, set_name), 'r') as f:
|
39 |
+
set_authors = [int(l.split(",")[0]) for l in f]
|
40 |
+
|
41 |
+
set_authors = set(set_authors)
|
42 |
+
|
43 |
+
print(f"{set_name} count: {len(set_authors)}")
|
44 |
+
for name, original_set in zip(original_set_names, original_set_ids):
|
45 |
+
intr = set_authors.intersection(original_set)
|
46 |
+
print(f"\t intersection with {name}: {len(intr)}")
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
if __name__ == "__main__":
|
51 |
+
test_split()
|
data/show_dataset.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pickle
|
3 |
+
import random
|
4 |
+
import shutil
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
from data.dataset import get_transform
|
11 |
+
|
12 |
+
|
13 |
+
def summarize_dataset(data: dict):
|
14 |
+
print(f"Training authors: {len(data['train'].keys())} \t Testing authors: {len(data['test'].keys())}")
|
15 |
+
training_images = sum([len(data['train'][k]) for k in data['train'].keys()])
|
16 |
+
testing_images = sum([len(data['test'][k]) for k in data['test'].keys()])
|
17 |
+
print(f"Training images: {training_images} \t Testing images: {testing_images}")
|
18 |
+
|
19 |
+
|
20 |
+
def compare_data(path_a: str, path_b: str):
|
21 |
+
with open(path_a, 'rb') as f:
|
22 |
+
data_a = pickle.load(f)
|
23 |
+
summarize_dataset(data_a)
|
24 |
+
|
25 |
+
with open(path_b, 'rb') as f:
|
26 |
+
data_b = pickle.load(f)
|
27 |
+
summarize_dataset(data_b)
|
28 |
+
|
29 |
+
training_a = data_a['train']
|
30 |
+
training_b = data_b['train']
|
31 |
+
|
32 |
+
training_a = {int(k): v for k, v in training_a.items()}
|
33 |
+
training_b = {int(k): v for k, v in training_b.items()}
|
34 |
+
|
35 |
+
while True:
|
36 |
+
author = random.choice(list(training_a.keys()))
|
37 |
+
|
38 |
+
if author in training_b.keys():
|
39 |
+
author_images_a = [np.array(im_dict["img"]) for im_dict in training_a[author]]
|
40 |
+
author_images_b = [np.array(im_dict["img"]) for im_dict in training_b[author]]
|
41 |
+
|
42 |
+
labels_a = [str(im_dict["label"]) for im_dict in training_a[author]]
|
43 |
+
labels_b = [str(im_dict["label"]) for im_dict in training_b[author]]
|
44 |
+
|
45 |
+
vis_a = np.hstack(author_images_a[:10])
|
46 |
+
vis_b = np.hstack(author_images_b[:10])
|
47 |
+
|
48 |
+
cv2.imshow("Author a", vis_a)
|
49 |
+
cv2.imshow("Author b", vis_b)
|
50 |
+
|
51 |
+
cv2.waitKey(0)
|
52 |
+
|
53 |
+
else:
|
54 |
+
print(f"Author: {author} not found in second dataset")
|
55 |
+
|
56 |
+
|
57 |
+
def show_dataset(path: str, samples: int = 10):
|
58 |
+
with open(path, 'rb') as f:
|
59 |
+
data = pickle.load(f)
|
60 |
+
summarize_dataset(data)
|
61 |
+
|
62 |
+
training = data['train']
|
63 |
+
|
64 |
+
author = training['013']
|
65 |
+
author_images = [np.array(im_dict["img"]).astype(np.uint8) for im_dict in author]
|
66 |
+
|
67 |
+
for img in author_images:
|
68 |
+
cv2.imshow('image', img)
|
69 |
+
cv2.waitKey(0)
|
70 |
+
|
71 |
+
for author in list(training.keys()):
|
72 |
+
|
73 |
+
author_images = [np.array(im_dict["img"]).astype(np.uint8) for im_dict in training[author]]
|
74 |
+
labels = [str(im_dict["label"]) for im_dict in training[author]]
|
75 |
+
|
76 |
+
vis = np.hstack(author_images[:samples])
|
77 |
+
print(f"Author: {author}")
|
78 |
+
cv2.destroyAllWindows()
|
79 |
+
cv2.imshow("vis", vis)
|
80 |
+
cv2.waitKey(0)
|
81 |
+
|
82 |
+
|
83 |
+
def test_transform(path: str):
|
84 |
+
with open(path, 'rb') as f:
|
85 |
+
data = pickle.load(f)
|
86 |
+
summarize_dataset(data)
|
87 |
+
|
88 |
+
training = data['train']
|
89 |
+
transform = get_transform(grayscale=True)
|
90 |
+
|
91 |
+
for author_id in training.keys():
|
92 |
+
author = training[author_id]
|
93 |
+
for image_dict in author:
|
94 |
+
original_image = image_dict['img'].convert('L')
|
95 |
+
transformed_image = transform(original_image).detach().numpy()
|
96 |
+
restored_image = (((transformed_image + 1) / 2) * 255).astype(np.uint8)
|
97 |
+
restored_image = np.squeeze(restored_image)
|
98 |
+
original_image = np.array(original_image)
|
99 |
+
|
100 |
+
wrong_pixels = (original_image != restored_image).astype(np.uint8) * 255
|
101 |
+
|
102 |
+
combined = np.hstack((restored_image, original_image, wrong_pixels))
|
103 |
+
|
104 |
+
cv2.imshow("original", original_image)
|
105 |
+
cv2.imshow("restored", restored_image)
|
106 |
+
cv2.imshow("combined", combined)
|
107 |
+
|
108 |
+
f, ax = plt.subplots(1, 2)
|
109 |
+
ax[0].hist(original_image.flatten())
|
110 |
+
ax[1].hist(restored_image.flatten())
|
111 |
+
plt.show()
|
112 |
+
|
113 |
+
cv2.waitKey(0)
|
114 |
+
|
115 |
+
def dump_words():
|
116 |
+
data_path = r"..\files\IAM-32.pickle"
|
117 |
+
|
118 |
+
p_mark = 'point'
|
119 |
+
p = '.'
|
120 |
+
|
121 |
+
with open(data_path, 'rb') as f:
|
122 |
+
data = pickle.load(f)
|
123 |
+
|
124 |
+
training = data['train']
|
125 |
+
|
126 |
+
target_folder = f"../saved_images/debug/{p_mark}"
|
127 |
+
|
128 |
+
if os.path.exists(target_folder):
|
129 |
+
shutil.rmtree(target_folder)
|
130 |
+
|
131 |
+
os.mkdir(target_folder)
|
132 |
+
|
133 |
+
count = 0
|
134 |
+
|
135 |
+
for author in list(training.keys()):
|
136 |
+
|
137 |
+
author_images = [np.array(im_dict["img"]).astype(np.uint8) for im_dict in training[author]]
|
138 |
+
labels = [str(im_dict["label"]) for im_dict in training[author]]
|
139 |
+
|
140 |
+
for img, label in zip(author_images, labels):
|
141 |
+
if p in label:
|
142 |
+
cv2.imwrite(os.path.join(target_folder, f"{count}.png"), img)
|
143 |
+
count += 1
|
144 |
+
|
145 |
+
|
146 |
+
if __name__ == "__main__":
|
147 |
+
test_transform("../files/IAM-32.pickle")
|
148 |
+
#show_dataset("../files/IAM-32.pickle")
|
149 |
+
#compare_data(r"../files/IAM-32.pickle", r"../files/_IAM-32.pickle")
|
files/IAM-32-pa.pickle
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:92bff8330e8f404b5f382846266257b5cac45d6c27908df5c3ee7d0c77a0ee95
|
3 |
+
size 245981914
|
files/IAM-32.pickle
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c56d4055470c26a30dbbdf7f2e232eb86ffc714b803651dbac5576ee2bc97937
|
3 |
+
size 590113103
|
files/cvl_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b47fe3ffe291bb3e52db0643125a99206840884181ed21312bcbe2cdd86303f0
|
3 |
+
size 163050271
|
files/english_words.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
files/files
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
files
|
files/hwt.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:999f85148e34e30242c1aa9ed7063c9dbc9da008f868ed26cb6ed923f9d8c0bd
|
3 |
+
size 163050271
|
files/resnet_18_pretrained.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bf5f5f6a94152dc4b0e9f2e390d658ef621efead3824cd494d3a82a6c8ceb5e0
|
3 |
+
size 48833885
|
files/unifont.pickle
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0804979068f0d169b343fbe0fe8d7ff478165d07a671fcf52e20f625db8e7f9f
|
3 |
+
size 16978300
|
files/vatr.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:65b67f1738bf74d5bf612f7f35e2c8c9560568d7efe422beb9132e1bb68bbef8
|
3 |
+
size 565758212
|
files/vatrpp.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c02f950d19cf3df3cfa6fe97114557e16a51bd3b910da6b5a2359a29851b84b6
|
3 |
+
size 561198056
|
generate.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from generate import generate_text, generate_authors, generate_fid, generate_page, generate_ocr, generate_ocr_msgpack
|
3 |
+
from generate.ocr import generate_ocr_reference
|
4 |
+
from util.misc import add_vatr_args
|
5 |
+
|
6 |
+
if __name__ == '__main__':
|
7 |
+
parser = argparse.ArgumentParser()
|
8 |
+
parser.add_argument("action", choices=['text', 'fid', 'page', 'authors', 'ocr'])
|
9 |
+
|
10 |
+
parser.add_argument("-s", "--style-folder", default='files/style_samples/00', type=str)
|
11 |
+
parser.add_argument("-t", "--text", default='That\'s one small step for man, one giant leap for mankind ΑαΒβΓγΔδ', type=str)
|
12 |
+
parser.add_argument("--text-path", default=None, type=str, help='Path to text file with texts to generate')
|
13 |
+
parser.add_argument("-c", "--checkpoint", default='files/vatr.pth', type=str)
|
14 |
+
parser.add_argument("-o", "--output", default=None, type=str)
|
15 |
+
parser.add_argument("--count", default=1000, type=int)
|
16 |
+
parser.add_argument("-a", "--align", action='store_true')
|
17 |
+
parser.add_argument("--at-once", action='store_true')
|
18 |
+
parser.add_argument("--output-style", action='store_true')
|
19 |
+
parser.add_argument("-d", "--dataset-path", type=str)
|
20 |
+
parser.add_argument("--target-dataset-path", type=str, default=None)
|
21 |
+
parser.add_argument("--charset-file", type=str, default=None)
|
22 |
+
parser.add_argument("--interp-styles", action='store_true')
|
23 |
+
|
24 |
+
parser.add_argument("--test-only", action='store_true')
|
25 |
+
parser.add_argument("--fake-only", action='store_true')
|
26 |
+
parser.add_argument("--all-epochs", action='store_true')
|
27 |
+
parser.add_argument("--long-tail", action='store_true')
|
28 |
+
parser.add_argument("--msgpack", action='store_true')
|
29 |
+
parser.add_argument("--reference", action='store_true')
|
30 |
+
parser.add_argument("--test-set", action='store_true')
|
31 |
+
|
32 |
+
parser = add_vatr_args(parser)
|
33 |
+
args = parser.parse_args()
|
34 |
+
|
35 |
+
if args.action == 'text':
|
36 |
+
generate_text(args)
|
37 |
+
elif args.action == 'authors':
|
38 |
+
generate_authors(args)
|
39 |
+
elif args.action == 'fid':
|
40 |
+
generate_fid(args)
|
41 |
+
elif args.action == 'page':
|
42 |
+
generate_page(args)
|
43 |
+
elif args.action == 'ocr':
|
44 |
+
if args.msgpack:
|
45 |
+
generate_ocr_msgpack(args)
|
46 |
+
elif args.reference:
|
47 |
+
generate_ocr_reference(args)
|
48 |
+
else:
|
49 |
+
generate_ocr(args)
|
generate/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from generate.text import generate_text
|
2 |
+
from generate.fid import generate_fid
|
3 |
+
from generate.authors import generate_authors
|
4 |
+
from generate.page import generate_page
|
5 |
+
from generate.ocr import generate_ocr, generate_ocr_msgpack
|
generate/authors.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from data.dataset import CollectionTextDataset, TextDataset
|
8 |
+
from generate.util import stack_lines
|
9 |
+
from generate.writer import Writer
|
10 |
+
|
11 |
+
|
12 |
+
def generate_authors(args):
|
13 |
+
dataset = CollectionTextDataset(
|
14 |
+
args.dataset, 'files', TextDataset, file_suffix=args.file_suffix, num_examples=args.num_examples,
|
15 |
+
collator_resolution=args.resolution, validation=args.test_set
|
16 |
+
)
|
17 |
+
|
18 |
+
args.num_writers = dataset.num_writers
|
19 |
+
|
20 |
+
writer = Writer(args.checkpoint, args, only_generator=True)
|
21 |
+
|
22 |
+
if args.text.endswith(".txt"):
|
23 |
+
with open(args.text, 'r') as f:
|
24 |
+
lines = [l.rstrip() for l in f]
|
25 |
+
else:
|
26 |
+
lines = [args.text]
|
27 |
+
|
28 |
+
output_dir = "saved_images/author_samples/"
|
29 |
+
if os.path.exists(output_dir):
|
30 |
+
shutil.rmtree(output_dir)
|
31 |
+
os.mkdir(output_dir)
|
32 |
+
|
33 |
+
fakes, author_ids, style_images = writer.generate_authors(lines, dataset, args.align, args.at_once)
|
34 |
+
|
35 |
+
for fake, author_id, style in zip(fakes, author_ids, style_images):
|
36 |
+
author_dir = os.path.join(output_dir, str(author_id))
|
37 |
+
os.mkdir(author_dir)
|
38 |
+
|
39 |
+
for i, line in enumerate(fake):
|
40 |
+
cv2.imwrite(os.path.join(author_dir, f"line_{i}.png"), line)
|
41 |
+
|
42 |
+
total = stack_lines(fake)
|
43 |
+
cv2.imwrite(os.path.join(author_dir, "total.png"), total)
|
44 |
+
|
45 |
+
if args.output_style:
|
46 |
+
for i, image in enumerate(style):
|
47 |
+
cv2.imwrite(os.path.join(author_dir, f"style_{i}.png"), image)
|
48 |
+
|
generate/fid.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.utils.data
|
6 |
+
|
7 |
+
from data.dataset import FidDataset
|
8 |
+
from generate.writer import Writer
|
9 |
+
|
10 |
+
|
11 |
+
def generate_fid(args):
|
12 |
+
if 'iam' in args.target_dataset_path.lower():
|
13 |
+
args.num_writers = 339
|
14 |
+
elif 'cvl' in args.target_dataset_path.lower():
|
15 |
+
args.num_writers = 283
|
16 |
+
else:
|
17 |
+
raise ValueError
|
18 |
+
|
19 |
+
args.vocab_size = len(args.alphabet)
|
20 |
+
|
21 |
+
dataset_train = FidDataset(base_path=args.target_dataset_path, num_examples=args.num_examples, collator_resolution=args.resolution, mode='train', style_dataset=args.dataset_path)
|
22 |
+
train_loader = torch.utils.data.DataLoader(
|
23 |
+
dataset_train,
|
24 |
+
batch_size=args.batch_size,
|
25 |
+
shuffle=False,
|
26 |
+
num_workers=args.num_workers,
|
27 |
+
pin_memory=True, drop_last=False,
|
28 |
+
collate_fn=dataset_train.collate_fn
|
29 |
+
)
|
30 |
+
|
31 |
+
dataset_test = FidDataset(base_path=args.target_dataset_path, num_examples=args.num_examples, collator_resolution=args.resolution, mode='test', style_dataset=args.dataset_path)
|
32 |
+
test_loader = torch.utils.data.DataLoader(
|
33 |
+
dataset_test,
|
34 |
+
batch_size=args.batch_size,
|
35 |
+
shuffle=False,
|
36 |
+
num_workers=0,
|
37 |
+
pin_memory=True, drop_last=False,
|
38 |
+
collate_fn=dataset_test.collate_fn
|
39 |
+
)
|
40 |
+
|
41 |
+
args.output = 'saved_images' if args.output is None else args.output
|
42 |
+
args.output = Path(args.output) / 'fid' / args.target_dataset_path.split("/")[-1].replace(".pickle", "").replace("-", "")
|
43 |
+
|
44 |
+
model_folder = args.checkpoint.split("/")[-2] if args.checkpoint.endswith(".pth") else args.checkpoint.split("/")[-1]
|
45 |
+
model_tag = model_folder.split("-")[-1] if "-" in model_folder else "vatr"
|
46 |
+
model_tag += "_" + args.dataset_path.split("/")[-1].replace(".pickle", "").replace("-", "")
|
47 |
+
|
48 |
+
if not args.all_epochs:
|
49 |
+
writer = Writer(args.checkpoint, args, only_generator=True)
|
50 |
+
if not args.test_only:
|
51 |
+
writer.generate_fid(args.output, train_loader, model_tag=model_tag, split='train', fake_only=args.fake_only, long_tail_only=args.long_tail)
|
52 |
+
writer.generate_fid(args.output, test_loader, model_tag=model_tag, split='test', fake_only=args.fake_only, long_tail_only=args.long_tail)
|
53 |
+
else:
|
54 |
+
epochs = sorted([int(f.split("_")[0]) for f in os.listdir(args.checkpoint) if "_" in f])
|
55 |
+
generate_real = True
|
56 |
+
|
57 |
+
for epoch in epochs:
|
58 |
+
checkpoint_path = os.path.join(args.checkpoint, f"{str(epoch).zfill(4)}_model.pth")
|
59 |
+
writer = Writer(checkpoint_path, args, only_generator=True)
|
60 |
+
writer.generate_fid(args.output, test_loader, model_tag=f"{model_tag}_{epoch}", split='test', fake_only=not generate_real, long_tail_only=args.long_tail)
|
61 |
+
generate_real = False
|
62 |
+
|
63 |
+
print('Done')
|
generate/ocr.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import msgpack
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from data.dataset import CollectionTextDataset, TextDataset, FolderDataset, FidDataset, get_dataset_path
|
9 |
+
from generate.writer import Writer
|
10 |
+
from util.text import get_generator
|
11 |
+
|
12 |
+
|
13 |
+
def generate_ocr(args):
|
14 |
+
"""
|
15 |
+
Generate OCR training data. Words generated are from given text generator.
|
16 |
+
"""
|
17 |
+
dataset = CollectionTextDataset(
|
18 |
+
args.dataset, 'files', TextDataset, file_suffix=args.file_suffix, num_examples=args.num_examples,
|
19 |
+
collator_resolution=args.resolution, validation=True
|
20 |
+
)
|
21 |
+
args.num_writers = dataset.num_writers
|
22 |
+
|
23 |
+
writer = Writer(args.checkpoint, args, only_generator=True)
|
24 |
+
|
25 |
+
generator = get_generator(args)
|
26 |
+
|
27 |
+
writer.generate_ocr(dataset, args.count, interpolate_style=args.interp_styles, output_folder=args.output, text_generator=generator)
|
28 |
+
|
29 |
+
|
30 |
+
def generate_ocr_reference(args):
|
31 |
+
"""
|
32 |
+
Generate OCR training data. Words generated are words from given dataset. Reference words are also saved.
|
33 |
+
"""
|
34 |
+
dataset = CollectionTextDataset(
|
35 |
+
args.dataset, 'files', TextDataset, file_suffix=args.file_suffix, num_examples=args.num_examples,
|
36 |
+
collator_resolution=args.resolution, validation=True
|
37 |
+
)
|
38 |
+
|
39 |
+
#dataset = FidDataset(get_dataset_path(args.dataset, 32, args.file_suffix, 'files'), mode='test', collator_resolution=args.resolution)
|
40 |
+
|
41 |
+
args.num_writers = dataset.num_writers
|
42 |
+
|
43 |
+
writer = Writer(args.checkpoint, args, only_generator=True)
|
44 |
+
|
45 |
+
writer.generate_ocr(dataset, args.count, interpolate_style=args.interp_styles, output_folder=args.output, long_tail=args.long_tail)
|
46 |
+
|
47 |
+
|
48 |
+
def generate_ocr_msgpack(args):
|
49 |
+
"""
|
50 |
+
Generate OCR dataset. Words generated are specified in given msgpack file
|
51 |
+
"""
|
52 |
+
dataset = FolderDataset(args.dataset_path)
|
53 |
+
args.num_writers = 339
|
54 |
+
|
55 |
+
if args.charset_file:
|
56 |
+
charset = msgpack.load(open(args.charset_file, 'rb'), use_list=False, strict_map_key=False)
|
57 |
+
args.alphabet = "".join(charset['char2idx'].keys())
|
58 |
+
|
59 |
+
writer = Writer(args.checkpoint, args, only_generator=True)
|
60 |
+
|
61 |
+
lines = msgpack.load(open(args.text_path, 'rb'), use_list=False)
|
62 |
+
|
63 |
+
print(f"Generating {len(lines)} to {args.output}")
|
64 |
+
|
65 |
+
for i, (filename, target) in enumerate(lines):
|
66 |
+
if not os.path.exists(os.path.join(args.output, filename)):
|
67 |
+
style = torch.unsqueeze(dataset.sample_style()['simg'], dim=0).to(args.device)
|
68 |
+
fake = writer.create_fake_sentence(style, target, at_once=True)
|
69 |
+
|
70 |
+
cv2.imwrite(os.path.join(args.output, filename), fake)
|
71 |
+
|
72 |
+
print(f"Done")
|
generate/page.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from data.dataset import CollectionTextDataset, TextDataset
|
8 |
+
from models.model import VATr
|
9 |
+
from util.loading import load_checkpoint, load_generator
|
10 |
+
|
11 |
+
|
12 |
+
def generate_page(args):
|
13 |
+
args.output = 'vatr' if args.output is None else args.output
|
14 |
+
|
15 |
+
args.vocab_size = len(args.alphabet)
|
16 |
+
|
17 |
+
dataset = CollectionTextDataset(
|
18 |
+
args.dataset, 'files', TextDataset, file_suffix=args.file_suffix, num_examples=args.num_examples,
|
19 |
+
collator_resolution=args.resolution
|
20 |
+
)
|
21 |
+
datasetval = CollectionTextDataset(
|
22 |
+
args.dataset, 'files', TextDataset, file_suffix=args.file_suffix, num_examples=args.num_examples,
|
23 |
+
collator_resolution=args.resolution, validation=True
|
24 |
+
)
|
25 |
+
|
26 |
+
args.num_writers = dataset.num_writers
|
27 |
+
|
28 |
+
model = VATr(args)
|
29 |
+
checkpoint = torch.load(args.checkpoint, map_location=args.device)
|
30 |
+
model = load_generator(model, checkpoint)
|
31 |
+
|
32 |
+
train_loader = torch.utils.data.DataLoader(
|
33 |
+
dataset,
|
34 |
+
batch_size=8,
|
35 |
+
shuffle=True,
|
36 |
+
num_workers=0,
|
37 |
+
pin_memory=True, drop_last=True,
|
38 |
+
collate_fn=dataset.collate_fn)
|
39 |
+
|
40 |
+
val_loader = torch.utils.data.DataLoader(
|
41 |
+
datasetval,
|
42 |
+
batch_size=8,
|
43 |
+
shuffle=True,
|
44 |
+
num_workers=0,
|
45 |
+
pin_memory=True, drop_last=True,
|
46 |
+
collate_fn=datasetval.collate_fn)
|
47 |
+
|
48 |
+
data_train = next(iter(train_loader))
|
49 |
+
data_val = next(iter(val_loader))
|
50 |
+
|
51 |
+
model.eval()
|
52 |
+
with torch.no_grad():
|
53 |
+
page = model._generate_page(data_train['simg'].to(args.device), data_val['swids'])
|
54 |
+
page_val = model._generate_page(data_val['simg'].to(args.device), data_val['swids'])
|
55 |
+
|
56 |
+
cv2.imwrite(os.path.join("saved_images", "pages", f"{args.output}_train.png"), (page * 255).astype(np.uint8))
|
57 |
+
cv2.imwrite(os.path.join("saved_images", "pages", f"{args.output}_val.png"), (page_val * 255).astype(np.uint8))
|
generate/text.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
|
5 |
+
from generate.writer import Writer
|
6 |
+
|
7 |
+
|
8 |
+
def generate_text(args):
|
9 |
+
if args.text_path is not None:
|
10 |
+
with open(args.text_path, 'r') as f:
|
11 |
+
args.text = f.read()
|
12 |
+
args.text = args.text.splitlines()
|
13 |
+
args.output = 'files/output.png' if args.output is None else args.output
|
14 |
+
args.output = Path(args.output)
|
15 |
+
args.output.parent.mkdir(parents=True, exist_ok=True)
|
16 |
+
args.num_writers = 0
|
17 |
+
|
18 |
+
writer = Writer(args.checkpoint, args, only_generator=True)
|
19 |
+
writer.set_style_folder(args.style_folder)
|
20 |
+
fakes = writer.generate(args.text, args.align)
|
21 |
+
for i, fake in enumerate(fakes):
|
22 |
+
dst_path = args.output.parent / (args.output.stem + f'_{i:03d}' + args.output.suffix)
|
23 |
+
cv2.imwrite(str(dst_path), fake)
|
24 |
+
print('Done')
|
generate/util.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def stack_lines(lines: list, h_gap: int = 6):
|
5 |
+
width = max([im.shape[1] for im in lines])
|
6 |
+
height = (lines[0].shape[0] + h_gap) * len(lines)
|
7 |
+
|
8 |
+
result = np.ones((height, width)) * 255
|
9 |
+
|
10 |
+
y_pos = 0
|
11 |
+
for line in lines:
|
12 |
+
result[y_pos:y_pos + line.shape[0], 0:line.shape[1]] = line
|
13 |
+
y_pos += line.shape[0] + h_gap
|
14 |
+
|
15 |
+
return result
|
generate/writer.py
ADDED
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import shutil
|
5 |
+
from collections import defaultdict
|
6 |
+
import time
|
7 |
+
from datetime import timedelta
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
import cv2
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
|
14 |
+
from data.dataset import FolderDataset
|
15 |
+
from models.model import VATr
|
16 |
+
from util.loading import load_checkpoint, load_generator
|
17 |
+
from util.misc import FakeArgs
|
18 |
+
from util.text import TextGenerator
|
19 |
+
from util.vision import detect_text_bounds
|
20 |
+
|
21 |
+
|
22 |
+
def get_long_tail_chars():
|
23 |
+
with open(f"files/longtail.txt", 'r') as f:
|
24 |
+
chars = [c.rstrip() for c in f]
|
25 |
+
|
26 |
+
chars.remove('')
|
27 |
+
|
28 |
+
return chars
|
29 |
+
|
30 |
+
|
31 |
+
class Writer:
|
32 |
+
def __init__(self, checkpoint_path, args, only_generator: bool = False):
|
33 |
+
self.model = VATr(args)
|
34 |
+
checkpoint = torch.load(checkpoint_path, map_location=args.device)
|
35 |
+
load_checkpoint(self.model, checkpoint) if not only_generator else load_generator(self.model, checkpoint)
|
36 |
+
self.model.eval()
|
37 |
+
self.style_dataset = None
|
38 |
+
|
39 |
+
def set_style_folder(self, style_folder, num_examples=15):
|
40 |
+
word_lengths = None
|
41 |
+
if os.path.exists(os.path.join(style_folder, "word_lengths.txt")):
|
42 |
+
word_lengths = {}
|
43 |
+
with open(os.path.join(style_folder, "word_lengths.txt"), 'r') as f:
|
44 |
+
for line in f:
|
45 |
+
word, length = line.rstrip().split(",")
|
46 |
+
word_lengths[word] = int(length)
|
47 |
+
|
48 |
+
self.style_dataset = FolderDataset(style_folder, num_examples=num_examples, word_lengths=word_lengths)
|
49 |
+
|
50 |
+
@torch.no_grad()
|
51 |
+
def generate(self, texts, align_words: bool = False, at_once: bool = False):
|
52 |
+
if isinstance(texts, str):
|
53 |
+
texts = [texts]
|
54 |
+
if self.style_dataset is None:
|
55 |
+
raise Exception('Style is not set')
|
56 |
+
|
57 |
+
fakes = []
|
58 |
+
for i, text in enumerate(texts, 1):
|
59 |
+
print(f'[{i}/{len(texts)}] Generating for text: {text}')
|
60 |
+
style = self.style_dataset.sample_style()
|
61 |
+
style_images = style['simg'].unsqueeze(0).to(self.model.args.device)
|
62 |
+
|
63 |
+
fake = self.create_fake_sentence(style_images, text, align_words, at_once)
|
64 |
+
|
65 |
+
fakes.append(fake)
|
66 |
+
return fakes
|
67 |
+
|
68 |
+
@torch.no_grad()
|
69 |
+
def create_fake_sentence(self, style_images, text, align_words=False, at_once=False):
|
70 |
+
text = "".join([c for c in text if c in self.model.args.alphabet])
|
71 |
+
|
72 |
+
text = text.split() if not at_once else [text]
|
73 |
+
gap = np.ones((32, 16))
|
74 |
+
|
75 |
+
text_encode, len_text, encode_pos = self.model.netconverter.encode(text)
|
76 |
+
text_encode = text_encode.to(self.model.args.device).unsqueeze(0)
|
77 |
+
|
78 |
+
fake = self.model._generate_fakes(style_images, text_encode, len_text)
|
79 |
+
if not at_once:
|
80 |
+
if align_words:
|
81 |
+
fake = self.stitch_words(fake, show_lines=False)
|
82 |
+
else:
|
83 |
+
fake = np.concatenate(sum([[img, gap] for img in fake], []), axis=1)[:, :-16]
|
84 |
+
else:
|
85 |
+
fake = fake[0]
|
86 |
+
fake = (fake * 255).astype(np.uint8)
|
87 |
+
|
88 |
+
return fake
|
89 |
+
|
90 |
+
@torch.no_grad()
|
91 |
+
def generate_authors(self, text, dataset, align_words: bool = False, at_once: bool = False):
|
92 |
+
fakes = []
|
93 |
+
author_ids = []
|
94 |
+
style = []
|
95 |
+
|
96 |
+
for item in dataset:
|
97 |
+
print(f"Generating author {item['wcl']}")
|
98 |
+
style_images = item['simg'].to(self.model.args.device).unsqueeze(0)
|
99 |
+
|
100 |
+
generated_lines = [self.create_fake_sentence(style_images, line, align_words, at_once) for line in text]
|
101 |
+
|
102 |
+
fakes.append(generated_lines)
|
103 |
+
author_ids.append(item['author_id'])
|
104 |
+
style.append((((item['simg'].numpy() + 1.0) / 2.0) * 255).astype(np.uint8))
|
105 |
+
|
106 |
+
return fakes, author_ids, style
|
107 |
+
|
108 |
+
@torch.no_grad()
|
109 |
+
def generate_characters(self, dataset, characters: str):
|
110 |
+
"""
|
111 |
+
Generate each of the given characters for each of the authors in the dataset.
|
112 |
+
"""
|
113 |
+
fakes = []
|
114 |
+
|
115 |
+
text_encode, len_text, encode_pos = self.model.netconverter.encode([c for c in characters])
|
116 |
+
text_encode = text_encode.to(self.model.args.device).unsqueeze(0)
|
117 |
+
|
118 |
+
for item in dataset:
|
119 |
+
print(f"Generating author {item['wcl']}")
|
120 |
+
style_images = item['simg'].to(self.model.args.device).unsqueeze(0)
|
121 |
+
fake = self.model.netG.evaluate(style_images, text_encode)
|
122 |
+
|
123 |
+
fakes.append(fake)
|
124 |
+
|
125 |
+
return fakes
|
126 |
+
|
127 |
+
@torch.no_grad()
|
128 |
+
def generate_batch(self, style_imgs, text):
|
129 |
+
"""
|
130 |
+
Given a batch of style images and text, generate images using the model
|
131 |
+
"""
|
132 |
+
device = self.model.args.device
|
133 |
+
text_encode, _, _ = self.model.netconverter.encode(text)
|
134 |
+
fakes, _ = self.model.netG(style_imgs.to(device), text_encode.to(device))
|
135 |
+
return fakes
|
136 |
+
|
137 |
+
@torch.no_grad()
|
138 |
+
def generate_ocr(self, dataset, number: int, output_folder: str = 'saved_images/ocr', interpolate_style: bool = False, text_generator: TextGenerator = None, long_tail: bool = False):
|
139 |
+
def create_and_write(style, text, interpolated=False):
|
140 |
+
nonlocal image_counter, annotations
|
141 |
+
|
142 |
+
text_encode, len_text, encode_pos = self.model.netconverter.encode([text])
|
143 |
+
text_encode = text_encode.to(self.model.args.device)
|
144 |
+
|
145 |
+
fake = self.model.netG.generate(style, text_encode)
|
146 |
+
|
147 |
+
fake = (fake + 1) / 2
|
148 |
+
fake = fake.cpu().numpy()
|
149 |
+
fake = np.squeeze((fake * 255).astype(np.uint8))
|
150 |
+
|
151 |
+
image_filename = f"{image_counter}.png" if not interpolated else f"{image_counter}_i.png"
|
152 |
+
|
153 |
+
cv2.imwrite(os.path.join(output_folder, "generated", image_filename), fake)
|
154 |
+
|
155 |
+
annotations.append((image_filename, text))
|
156 |
+
|
157 |
+
image_counter += 1
|
158 |
+
|
159 |
+
image_counter = 0
|
160 |
+
annotations = []
|
161 |
+
previous_style = None
|
162 |
+
long_tail_chars = get_long_tail_chars()
|
163 |
+
|
164 |
+
os.mkdir(os.path.join(output_folder, "generated"))
|
165 |
+
if text_generator is None:
|
166 |
+
os.mkdir(os.path.join(output_folder, "reference"))
|
167 |
+
|
168 |
+
while image_counter < number:
|
169 |
+
author_index = random.randint(0, len(dataset) - 1)
|
170 |
+
item = dataset[author_index]
|
171 |
+
|
172 |
+
style_images = item['simg'].to(self.model.args.device).unsqueeze(0)
|
173 |
+
style = self.model.netG.compute_style(style_images)
|
174 |
+
|
175 |
+
if interpolate_style and previous_style is not None:
|
176 |
+
factor = float(np.clip(random.gauss(0.5, 0.15), 0.0, 1.0))
|
177 |
+
intermediate_style = torch.lerp(previous_style, style, factor)
|
178 |
+
text = text_generator.generate()
|
179 |
+
|
180 |
+
create_and_write(intermediate_style, text, interpolated=True)
|
181 |
+
|
182 |
+
if text_generator is not None:
|
183 |
+
text = text_generator.generate()
|
184 |
+
else:
|
185 |
+
text = str(item['label'].decode())
|
186 |
+
|
187 |
+
if long_tail and not any(c in long_tail_chars for c in text):
|
188 |
+
continue
|
189 |
+
|
190 |
+
fake = (item['img'] + 1) / 2
|
191 |
+
fake = fake.cpu().numpy()
|
192 |
+
fake = np.squeeze((fake * 255).astype(np.uint8))
|
193 |
+
|
194 |
+
image_filename = f"{image_counter}.png"
|
195 |
+
|
196 |
+
cv2.imwrite(os.path.join(output_folder, "reference", image_filename), fake)
|
197 |
+
|
198 |
+
create_and_write(style, text)
|
199 |
+
|
200 |
+
previous_style = style
|
201 |
+
|
202 |
+
if text_generator is None:
|
203 |
+
with open(os.path.join(output_folder, "reference", "labels.csv"), 'w') as fr:
|
204 |
+
fr.write(f"filename,words\n")
|
205 |
+
for annotation in annotations:
|
206 |
+
fr.write(f"{annotation[0]},{annotation[1]}\n")
|
207 |
+
|
208 |
+
with open(os.path.join(output_folder, "generated", "labels.csv"), 'w') as fg:
|
209 |
+
fg.write(f"filename,words\n")
|
210 |
+
for annotation in annotations:
|
211 |
+
fg.write(f"{annotation[0]},{annotation[1]}\n")
|
212 |
+
|
213 |
+
|
214 |
+
@staticmethod
|
215 |
+
def stitch_words(words: list, show_lines: bool = False, scale_words: bool = False):
|
216 |
+
gap_width = 16
|
217 |
+
|
218 |
+
bottom_lines = []
|
219 |
+
top_lines = []
|
220 |
+
for i in range(len(words)):
|
221 |
+
b, t = detect_text_bounds(words[i])
|
222 |
+
bottom_lines.append(b)
|
223 |
+
top_lines.append(t)
|
224 |
+
if show_lines:
|
225 |
+
words[i] = cv2.line(words[i], (0, b), (words[i].shape[1], b), (0, 0, 1.0))
|
226 |
+
words[i] = cv2.line(words[i], (0, t), (words[i].shape[1], t), (1.0, 0, 0))
|
227 |
+
|
228 |
+
bottom_lines = np.array(bottom_lines, dtype=float)
|
229 |
+
|
230 |
+
if scale_words:
|
231 |
+
top_lines = np.array(top_lines, dtype=float)
|
232 |
+
gaps = bottom_lines - top_lines
|
233 |
+
target_gap = np.mean(gaps)
|
234 |
+
scales = target_gap / gaps
|
235 |
+
|
236 |
+
bottom_lines *= scales
|
237 |
+
top_lines *= scales
|
238 |
+
words = [cv2.resize(word, None, fx=scale, fy=scale) for word, scale in zip(words, scales)]
|
239 |
+
|
240 |
+
highest = np.max(bottom_lines)
|
241 |
+
offsets = highest - bottom_lines
|
242 |
+
height = np.max(offsets + [word.shape[0] for word in words])
|
243 |
+
|
244 |
+
result = np.ones((int(height), gap_width * len(words) + sum([w.shape[1] for w in words])))
|
245 |
+
|
246 |
+
x_pos = 0
|
247 |
+
for bottom_line, word in zip(bottom_lines, words):
|
248 |
+
offset = int(highest - bottom_line)
|
249 |
+
|
250 |
+
result[offset:offset + word.shape[0], x_pos:x_pos+word.shape[1]] = word
|
251 |
+
|
252 |
+
x_pos += word.shape[1] + gap_width
|
253 |
+
|
254 |
+
return result
|
255 |
+
|
256 |
+
@torch.no_grad()
|
257 |
+
def generate_fid(self, path, loader, model_tag, split='train', fake_only=False, long_tail_only=False):
|
258 |
+
if not isinstance(path, Path):
|
259 |
+
path = Path(path)
|
260 |
+
|
261 |
+
path.mkdir(exist_ok=True, parents=True)
|
262 |
+
|
263 |
+
appendix = f"{split}" if not long_tail_only else f"{split}_lt"
|
264 |
+
|
265 |
+
real_base = path / f'real_{appendix}'
|
266 |
+
fake_base = path / model_tag / f'fake_{appendix}'
|
267 |
+
|
268 |
+
if real_base.exists() and not fake_only:
|
269 |
+
shutil.rmtree(real_base)
|
270 |
+
|
271 |
+
if fake_base.exists():
|
272 |
+
shutil.rmtree(fake_base)
|
273 |
+
|
274 |
+
real_base.mkdir(exist_ok=True)
|
275 |
+
fake_base.mkdir(exist_ok=True, parents=True)
|
276 |
+
|
277 |
+
print('Saving images...')
|
278 |
+
|
279 |
+
print(' Saving images on {}'.format(str(real_base)))
|
280 |
+
print(' Saving images on {}'.format(str(fake_base)))
|
281 |
+
|
282 |
+
long_tail_chars = get_long_tail_chars()
|
283 |
+
counter = 0
|
284 |
+
ann = defaultdict(lambda: {})
|
285 |
+
start_time = time.time()
|
286 |
+
for step, data in enumerate(loader):
|
287 |
+
style_images = data['simg'].to(self.model.args.device)
|
288 |
+
|
289 |
+
texts = [l.decode('utf-8') for l in data['label']]
|
290 |
+
texts = [t.encode('utf-8') for t in texts]
|
291 |
+
eval_text_encode, eval_len_text, _ = self.model.netconverter.encode(texts)
|
292 |
+
eval_text_encode = eval_text_encode.to(self.model.args.device).unsqueeze(1)
|
293 |
+
|
294 |
+
vis_style = np.vstack(style_images[0].detach().cpu().numpy())
|
295 |
+
vis_style = ((vis_style + 1) / 2) * 255
|
296 |
+
|
297 |
+
fakes = self.model.netG.evaluate(style_images, eval_text_encode)
|
298 |
+
fake_images = torch.cat(fakes, 1).detach().cpu().numpy()
|
299 |
+
real_images = data['img'].detach().cpu().numpy()
|
300 |
+
writer_ids = data['wcl'].int().tolist()
|
301 |
+
|
302 |
+
for i, (fake, real, wid, lb, img_id) in enumerate(zip(fake_images, real_images, writer_ids, data['label'], data['idx'])):
|
303 |
+
lb = lb.decode()
|
304 |
+
ann[f"{wid:03d}"][f'{img_id:05d}'] = lb
|
305 |
+
img_id = f'{img_id:05d}.png'
|
306 |
+
|
307 |
+
is_long_tail = any(c in long_tail_chars for c in lb)
|
308 |
+
|
309 |
+
if long_tail_only and not is_long_tail:
|
310 |
+
continue
|
311 |
+
|
312 |
+
fake_img_path = fake_base / f"{wid:03d}" / img_id
|
313 |
+
fake_img_path.parent.mkdir(exist_ok=True, parents=True)
|
314 |
+
cv2.imwrite(str(fake_img_path), 255 * ((fake.squeeze() + 1) / 2))
|
315 |
+
|
316 |
+
if not fake_only:
|
317 |
+
real_img_path = real_base / f"{wid:03d}" / img_id
|
318 |
+
real_img_path.parent.mkdir(exist_ok=True, parents=True)
|
319 |
+
cv2.imwrite(str(real_img_path), 255 * ((real.squeeze() + 1) / 2))
|
320 |
+
|
321 |
+
counter += 1
|
322 |
+
|
323 |
+
eta = (time.time() - start_time) / (step + 1) * (len(loader) - step - 1)
|
324 |
+
eta = str(timedelta(seconds=eta))
|
325 |
+
if step % 100 == 0:
|
326 |
+
print(f'[{(step + 1) / len(loader) * 100:.02f}%][{counter:05d}] ETA {eta}')
|
327 |
+
|
328 |
+
with open(path / 'ann.json', 'w') as f:
|
329 |
+
json.dump(ann, f)
|
generation_config.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"transformers_version": "4.46.2"
|
4 |
+
}
|
hwt/config.json
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_noise": false,
|
3 |
+
"alphabet": "Only thewigsofrcvdampbkuq.A-210xT5'MDL,RYHJ\"ISPWENj&BC93VGFKz();#:!7U64Q8?+*ZX/%",
|
4 |
+
"architectures": [
|
5 |
+
"VATrPP"
|
6 |
+
],
|
7 |
+
"augment_ocr": false,
|
8 |
+
"batch_size": 8,
|
9 |
+
"corpus": "standard",
|
10 |
+
"d_crop_size": null,
|
11 |
+
"d_lr": 1e-05,
|
12 |
+
"dataset": "IAM",
|
13 |
+
"device": "cuda",
|
14 |
+
"english_words_path": "files/english_words.txt",
|
15 |
+
"epochs": 100000,
|
16 |
+
"feat_model_path": "files/resnet_18_pretrained.pth",
|
17 |
+
"file_suffix": null,
|
18 |
+
"g_lr": 5e-05,
|
19 |
+
"img_height": 32,
|
20 |
+
"is_cycle": false,
|
21 |
+
"label_encoder": "default",
|
22 |
+
"model_type": "emuru",
|
23 |
+
"no_ocr_loss": false,
|
24 |
+
"no_writer_loss": false,
|
25 |
+
"num_examples": 15,
|
26 |
+
"num_words": 3,
|
27 |
+
"num_workers": 0,
|
28 |
+
"num_writers": 339,
|
29 |
+
"ocr_lr": 5e-05,
|
30 |
+
"query_input": "unifont",
|
31 |
+
"resolution": 16,
|
32 |
+
"save_model": 5,
|
33 |
+
"save_model_history": 500,
|
34 |
+
"save_model_path": "saved_models",
|
35 |
+
"seed": 742,
|
36 |
+
"special_alphabet": "\u0391\u03b1\u0392\u03b2\u0393\u03b3\u0394\u03b4\u0395\u03b5\u0396\u03b6\u0397\u03b7\u0398\u03b8\u0399\u03b9\u039a\u03ba\u039b\u03bb\u039c\u03bc\u039d\u03bd\u039e\u03be\u039f\u03bf\u03a0\u03c0\u03a1\u03c1\u03a3\u03c3\u03c2\u03a4\u03c4\u03a5\u03c5\u03a6\u03c6\u03a7\u03c7\u03a8\u03c8\u03a9\u03c9",
|
37 |
+
"tag": "debug",
|
38 |
+
"text_aug_type": "proportional",
|
39 |
+
"text_augment_strength": 0.0,
|
40 |
+
"torch_dtype": "float32",
|
41 |
+
"transformers_version": "4.46.2",
|
42 |
+
"vocab_size": 80,
|
43 |
+
"w_lr": 5e-05,
|
44 |
+
"wandb": false,
|
45 |
+
"writer_loss_weight": 1.0
|
46 |
+
}
|
hwt/generation_config.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"transformers_version": "4.46.2"
|
4 |
+
}
|
hwt/model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6c9bd990cdfd3a2a1683af05705c1f9a17b7f58b580a33853b0d0af7c57f7f2e
|
3 |
+
size 560965208
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3b1e4b7cae23652acd5c559117d06ef42fdd5317da2a5e0bc94ea44d8c0eb1ff
|
3 |
+
size 560965208
|
modeling_vatrpp.py
ADDED
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PreTrainedModel
|
2 |
+
from .configuration_vatrpp import VATrPPConfig
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import random
|
6 |
+
import shutil
|
7 |
+
from collections import defaultdict
|
8 |
+
import time
|
9 |
+
from datetime import timedelta
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
import cv2
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
|
16 |
+
from data.dataset import FolderDataset
|
17 |
+
from models.model import VATr
|
18 |
+
from util.loading import load_checkpoint, load_generator
|
19 |
+
from util.misc import FakeArgs
|
20 |
+
from util.text import TextGenerator
|
21 |
+
from util.vision import detect_text_bounds
|
22 |
+
from torchvision.transforms.functional import to_pil_image
|
23 |
+
|
24 |
+
|
25 |
+
def get_long_tail_chars():
|
26 |
+
with open(f"files/longtail.txt", 'r') as f:
|
27 |
+
chars = [c.rstrip() for c in f]
|
28 |
+
|
29 |
+
chars.remove('')
|
30 |
+
|
31 |
+
return chars
|
32 |
+
|
33 |
+
class VATrPP(PreTrainedModel):
|
34 |
+
config_class = VATrPPConfig
|
35 |
+
|
36 |
+
def __init__(self, config: VATrPPConfig) -> None:
|
37 |
+
super().__init__(config)
|
38 |
+
self.model = VATr(config)
|
39 |
+
self.model.eval()
|
40 |
+
|
41 |
+
def set_style_folder(self, style_folder, num_examples=15):
|
42 |
+
word_lengths = None
|
43 |
+
if os.path.exists(os.path.join(style_folder, "word_lengths.txt")):
|
44 |
+
word_lengths = {}
|
45 |
+
with open(os.path.join(style_folder, "word_lengths.txt"), 'r') as f:
|
46 |
+
for line in f:
|
47 |
+
word, length = line.rstrip().split(",")
|
48 |
+
word_lengths[word] = int(length)
|
49 |
+
|
50 |
+
self.style_dataset = FolderDataset(style_folder, num_examples=num_examples, word_lengths=word_lengths)
|
51 |
+
|
52 |
+
@torch.no_grad()
|
53 |
+
def generate(self, gen_text, style_imgs, align_words: bool = False, at_once: bool = False):
|
54 |
+
style_images = style_imgs.unsqueeze(0).to(self.model.args.device)
|
55 |
+
|
56 |
+
fake = self.create_fake_sentence(style_images, gen_text, align_words, at_once)
|
57 |
+
return to_pil_image(fake)
|
58 |
+
|
59 |
+
# @torch.no_grad()
|
60 |
+
# def generate(self, texts, align_words: bool = False, at_once: bool = False):
|
61 |
+
# if isinstance(texts, str):
|
62 |
+
# texts = [texts]
|
63 |
+
# if self.style_dataset is None:
|
64 |
+
# raise Exception('Style is not set')
|
65 |
+
|
66 |
+
# fakes = []
|
67 |
+
# for i, text in enumerate(texts, 1):
|
68 |
+
# print(f'[{i}/{len(texts)}] Generating for text: {text}')
|
69 |
+
# style = self.style_dataset.sample_style()
|
70 |
+
# style_images = style['simg'].unsqueeze(0).to(self.model.args.device)
|
71 |
+
|
72 |
+
# fake = self.create_fake_sentence(style_images, text, align_words, at_once)
|
73 |
+
|
74 |
+
# fakes.append(fake)
|
75 |
+
# return fakes
|
76 |
+
|
77 |
+
@torch.no_grad()
|
78 |
+
def create_fake_sentence(self, style_images, text, align_words=False, at_once=False):
|
79 |
+
text = "".join([c for c in text if c in self.model.args.alphabet])
|
80 |
+
|
81 |
+
text = text.split() if not at_once else [text]
|
82 |
+
gap = np.ones((32, 16))
|
83 |
+
|
84 |
+
text_encode, len_text, encode_pos = self.model.netconverter.encode(text)
|
85 |
+
text_encode = text_encode.to(self.model.args.device).unsqueeze(0)
|
86 |
+
|
87 |
+
fake = self.model._generate_fakes(style_images, text_encode, len_text)
|
88 |
+
if not at_once:
|
89 |
+
if align_words:
|
90 |
+
fake = self.stitch_words(fake, show_lines=False)
|
91 |
+
else:
|
92 |
+
fake = np.concatenate(sum([[img, gap] for img in fake], []), axis=1)[:, :-16]
|
93 |
+
else:
|
94 |
+
fake = fake[0]
|
95 |
+
fake = (fake * 255).astype(np.uint8)
|
96 |
+
|
97 |
+
return fake
|
98 |
+
|
99 |
+
@torch.no_grad()
|
100 |
+
def generate_authors(self, text, dataset, align_words: bool = False, at_once: bool = False):
|
101 |
+
fakes = []
|
102 |
+
author_ids = []
|
103 |
+
style = []
|
104 |
+
|
105 |
+
for item in dataset:
|
106 |
+
print(f"Generating author {item['wcl']}")
|
107 |
+
style_images = item['simg'].to(self.model.args.device).unsqueeze(0)
|
108 |
+
|
109 |
+
generated_lines = [self.create_fake_sentence(style_images, line, align_words, at_once) for line in text]
|
110 |
+
|
111 |
+
fakes.append(generated_lines)
|
112 |
+
author_ids.append(item['author_id'])
|
113 |
+
style.append((((item['simg'].numpy() + 1.0) / 2.0) * 255).astype(np.uint8))
|
114 |
+
|
115 |
+
return fakes, author_ids, style
|
116 |
+
|
117 |
+
@torch.no_grad()
|
118 |
+
def generate_characters(self, dataset, characters: str):
|
119 |
+
"""
|
120 |
+
Generate each of the given characters for each of the authors in the dataset.
|
121 |
+
"""
|
122 |
+
fakes = []
|
123 |
+
|
124 |
+
text_encode, len_text, encode_pos = self.model.netconverter.encode([c for c in characters])
|
125 |
+
text_encode = text_encode.to(self.model.args.device).unsqueeze(0)
|
126 |
+
|
127 |
+
for item in dataset:
|
128 |
+
print(f"Generating author {item['wcl']}")
|
129 |
+
style_images = item['simg'].to(self.model.args.device).unsqueeze(0)
|
130 |
+
fake = self.model.netG.evaluate(style_images, text_encode)
|
131 |
+
|
132 |
+
fakes.append(fake)
|
133 |
+
|
134 |
+
return fakes
|
135 |
+
|
136 |
+
@torch.no_grad()
|
137 |
+
def generate_batch(self, style_imgs, text):
|
138 |
+
"""
|
139 |
+
Given a batch of style images and text, generate images using the model
|
140 |
+
"""
|
141 |
+
device = self.model.args.device
|
142 |
+
text_encode, _, _ = self.model.netconverter.encode(text)
|
143 |
+
fakes, _ = self.model.netG(style_imgs.to(device), text_encode.to(device))
|
144 |
+
return fakes
|
145 |
+
|
146 |
+
@torch.no_grad()
|
147 |
+
def generate_ocr(self, dataset, number: int, output_folder: str = 'saved_images/ocr', interpolate_style: bool = False, text_generator: TextGenerator = None, long_tail: bool = False):
|
148 |
+
def create_and_write(style, text, interpolated=False):
|
149 |
+
nonlocal image_counter, annotations
|
150 |
+
|
151 |
+
text_encode, len_text, encode_pos = self.model.netconverter.encode([text])
|
152 |
+
text_encode = text_encode.to(self.model.args.device)
|
153 |
+
|
154 |
+
fake = self.model.netG.generate(style, text_encode)
|
155 |
+
|
156 |
+
fake = (fake + 1) / 2
|
157 |
+
fake = fake.cpu().numpy()
|
158 |
+
fake = np.squeeze((fake * 255).astype(np.uint8))
|
159 |
+
|
160 |
+
image_filename = f"{image_counter}.png" if not interpolated else f"{image_counter}_i.png"
|
161 |
+
|
162 |
+
cv2.imwrite(os.path.join(output_folder, "generated", image_filename), fake)
|
163 |
+
|
164 |
+
annotations.append((image_filename, text))
|
165 |
+
|
166 |
+
image_counter += 1
|
167 |
+
|
168 |
+
image_counter = 0
|
169 |
+
annotations = []
|
170 |
+
previous_style = None
|
171 |
+
long_tail_chars = get_long_tail_chars()
|
172 |
+
|
173 |
+
os.mkdir(os.path.join(output_folder, "generated"))
|
174 |
+
if text_generator is None:
|
175 |
+
os.mkdir(os.path.join(output_folder, "reference"))
|
176 |
+
|
177 |
+
while image_counter < number:
|
178 |
+
author_index = random.randint(0, len(dataset) - 1)
|
179 |
+
item = dataset[author_index]
|
180 |
+
|
181 |
+
style_images = item['simg'].to(self.model.args.device).unsqueeze(0)
|
182 |
+
style = self.model.netG.compute_style(style_images)
|
183 |
+
|
184 |
+
if interpolate_style and previous_style is not None:
|
185 |
+
factor = float(np.clip(random.gauss(0.5, 0.15), 0.0, 1.0))
|
186 |
+
intermediate_style = torch.lerp(previous_style, style, factor)
|
187 |
+
text = text_generator.generate()
|
188 |
+
|
189 |
+
create_and_write(intermediate_style, text, interpolated=True)
|
190 |
+
|
191 |
+
if text_generator is not None:
|
192 |
+
text = text_generator.generate()
|
193 |
+
else:
|
194 |
+
text = str(item['label'].decode())
|
195 |
+
|
196 |
+
if long_tail and not any(c in long_tail_chars for c in text):
|
197 |
+
continue
|
198 |
+
|
199 |
+
fake = (item['img'] + 1) / 2
|
200 |
+
fake = fake.cpu().numpy()
|
201 |
+
fake = np.squeeze((fake * 255).astype(np.uint8))
|
202 |
+
|
203 |
+
image_filename = f"{image_counter}.png"
|
204 |
+
|
205 |
+
cv2.imwrite(os.path.join(output_folder, "reference", image_filename), fake)
|
206 |
+
|
207 |
+
create_and_write(style, text)
|
208 |
+
|
209 |
+
previous_style = style
|
210 |
+
|
211 |
+
if text_generator is None:
|
212 |
+
with open(os.path.join(output_folder, "reference", "labels.csv"), 'w') as fr:
|
213 |
+
fr.write(f"filename,words\n")
|
214 |
+
for annotation in annotations:
|
215 |
+
fr.write(f"{annotation[0]},{annotation[1]}\n")
|
216 |
+
|
217 |
+
with open(os.path.join(output_folder, "generated", "labels.csv"), 'w') as fg:
|
218 |
+
fg.write(f"filename,words\n")
|
219 |
+
for annotation in annotations:
|
220 |
+
fg.write(f"{annotation[0]},{annotation[1]}\n")
|
221 |
+
|
222 |
+
|
223 |
+
@staticmethod
|
224 |
+
def stitch_words(words: list, show_lines: bool = False, scale_words: bool = False):
|
225 |
+
gap_width = 16
|
226 |
+
|
227 |
+
bottom_lines = []
|
228 |
+
top_lines = []
|
229 |
+
for i in range(len(words)):
|
230 |
+
b, t = detect_text_bounds(words[i])
|
231 |
+
bottom_lines.append(b)
|
232 |
+
top_lines.append(t)
|
233 |
+
if show_lines:
|
234 |
+
words[i] = cv2.line(words[i], (0, b), (words[i].shape[1], b), (0, 0, 1.0))
|
235 |
+
words[i] = cv2.line(words[i], (0, t), (words[i].shape[1], t), (1.0, 0, 0))
|
236 |
+
|
237 |
+
bottom_lines = np.array(bottom_lines, dtype=float)
|
238 |
+
|
239 |
+
if scale_words:
|
240 |
+
top_lines = np.array(top_lines, dtype=float)
|
241 |
+
gaps = bottom_lines - top_lines
|
242 |
+
target_gap = np.mean(gaps)
|
243 |
+
scales = target_gap / gaps
|
244 |
+
|
245 |
+
bottom_lines *= scales
|
246 |
+
top_lines *= scales
|
247 |
+
words = [cv2.resize(word, None, fx=scale, fy=scale) for word, scale in zip(words, scales)]
|
248 |
+
|
249 |
+
highest = np.max(bottom_lines)
|
250 |
+
offsets = highest - bottom_lines
|
251 |
+
height = np.max(offsets + [word.shape[0] for word in words])
|
252 |
+
|
253 |
+
result = np.ones((int(height), gap_width * len(words) + sum([w.shape[1] for w in words])))
|
254 |
+
|
255 |
+
x_pos = 0
|
256 |
+
for bottom_line, word in zip(bottom_lines, words):
|
257 |
+
offset = int(highest - bottom_line)
|
258 |
+
|
259 |
+
result[offset:offset + word.shape[0], x_pos:x_pos+word.shape[1]] = word
|
260 |
+
|
261 |
+
x_pos += word.shape[1] + gap_width
|
262 |
+
|
263 |
+
return result
|
264 |
+
|
265 |
+
@torch.no_grad()
|
266 |
+
def generate_fid(self, path, loader, model_tag, split='train', fake_only=False, long_tail_only=False):
|
267 |
+
if not isinstance(path, Path):
|
268 |
+
path = Path(path)
|
269 |
+
|
270 |
+
path.mkdir(exist_ok=True, parents=True)
|
271 |
+
|
272 |
+
appendix = f"{split}" if not long_tail_only else f"{split}_lt"
|
273 |
+
|
274 |
+
real_base = path / f'real_{appendix}'
|
275 |
+
fake_base = path / model_tag / f'fake_{appendix}'
|
276 |
+
|
277 |
+
if real_base.exists() and not fake_only:
|
278 |
+
shutil.rmtree(real_base)
|
279 |
+
|
280 |
+
if fake_base.exists():
|
281 |
+
shutil.rmtree(fake_base)
|
282 |
+
|
283 |
+
real_base.mkdir(exist_ok=True)
|
284 |
+
fake_base.mkdir(exist_ok=True, parents=True)
|
285 |
+
|
286 |
+
print('Saving images...')
|
287 |
+
|
288 |
+
print(' Saving images on {}'.format(str(real_base)))
|
289 |
+
print(' Saving images on {}'.format(str(fake_base)))
|
290 |
+
|
291 |
+
long_tail_chars = get_long_tail_chars()
|
292 |
+
counter = 0
|
293 |
+
ann = defaultdict(lambda: {})
|
294 |
+
start_time = time.time()
|
295 |
+
for step, data in enumerate(loader):
|
296 |
+
style_images = data['simg'].to(self.model.args.device)
|
297 |
+
|
298 |
+
texts = [l.decode('utf-8') for l in data['label']]
|
299 |
+
texts = [t.encode('utf-8') for t in texts]
|
300 |
+
eval_text_encode, eval_len_text, _ = self.model.netconverter.encode(texts)
|
301 |
+
eval_text_encode = eval_text_encode.to(self.model.args.device).unsqueeze(1)
|
302 |
+
|
303 |
+
vis_style = np.vstack(style_images[0].detach().cpu().numpy())
|
304 |
+
vis_style = ((vis_style + 1) / 2) * 255
|
305 |
+
|
306 |
+
fakes = self.model.netG.evaluate(style_images, eval_text_encode)
|
307 |
+
fake_images = torch.cat(fakes, 1).detach().cpu().numpy()
|
308 |
+
real_images = data['img'].detach().cpu().numpy()
|
309 |
+
writer_ids = data['wcl'].int().tolist()
|
310 |
+
|
311 |
+
for i, (fake, real, wid, lb, img_id) in enumerate(zip(fake_images, real_images, writer_ids, data['label'], data['idx'])):
|
312 |
+
lb = lb.decode()
|
313 |
+
ann[f"{wid:03d}"][f'{img_id:05d}'] = lb
|
314 |
+
img_id = f'{img_id:05d}.png'
|
315 |
+
|
316 |
+
is_long_tail = any(c in long_tail_chars for c in lb)
|
317 |
+
|
318 |
+
if long_tail_only and not is_long_tail:
|
319 |
+
continue
|
320 |
+
|
321 |
+
fake_img_path = fake_base / f"{wid:03d}" / img_id
|
322 |
+
fake_img_path.parent.mkdir(exist_ok=True, parents=True)
|
323 |
+
cv2.imwrite(str(fake_img_path), 255 * ((fake.squeeze() + 1) / 2))
|
324 |
+
|
325 |
+
if not fake_only:
|
326 |
+
real_img_path = real_base / f"{wid:03d}" / img_id
|
327 |
+
real_img_path.parent.mkdir(exist_ok=True, parents=True)
|
328 |
+
cv2.imwrite(str(real_img_path), 255 * ((real.squeeze() + 1) / 2))
|
329 |
+
|
330 |
+
counter += 1
|
331 |
+
|
332 |
+
eta = (time.time() - start_time) / (step + 1) * (len(loader) - step - 1)
|
333 |
+
eta = str(timedelta(seconds=eta))
|
334 |
+
if step % 100 == 0:
|
335 |
+
print(f'[{(step + 1) / len(loader) * 100:.02f}%][{counter:05d}] ETA {eta}')
|
336 |
+
|
337 |
+
with open(path / 'ann.json', 'w') as f:
|
338 |
+
json.dump(ann, f)
|
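A minimal sketch of consuming the `ann.json` file written by `generate_fid` above; the `fid_out` path is illustrative (it stands in for the `path` argument), not something shipped in this repository.

```python
# Minimal sketch: read the annotations written by generate_fid.
# "fid_out" is an illustrative output path (the `path` argument above).
import json

with open("fid_out/ann.json") as f:
    ann = json.load(f)  # {writer_id: {image_id: transcription}}

for writer_id, images in list(ann.items())[:2]:
    for image_id, text in list(images.items())[:3]:
        # images are saved as real_<split>/<writer_id>/<image_id>.png
        # and <model_tag>/fake_<split>/<writer_id>/<image_id>.png
        print(writer_id, image_id, text)
```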
models/BigGAN_layers.py
ADDED
@@ -0,0 +1,469 @@
1 |
+
''' Layers
|
2 |
+
This file contains various layers for the BigGAN models.
|
3 |
+
'''
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torch.nn import init
|
8 |
+
import torch.optim as optim
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from torch.nn import Parameter as P
|
11 |
+
|
12 |
+
from .sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d
|
13 |
+
|
14 |
+
# Projection of x onto y
|
15 |
+
def proj(x, y):
|
16 |
+
return torch.mm(y, x.t()) * y / torch.mm(y, y.t())
|
17 |
+
|
18 |
+
|
19 |
+
# Orthogonalize x wrt list of vectors ys
|
20 |
+
def gram_schmidt(x, ys):
|
21 |
+
for y in ys:
|
22 |
+
x = x - proj(x, y)
|
23 |
+
return x
|
24 |
+
|
25 |
+
|
26 |
+
# Apply num_itrs steps of the power method to estimate top N singular values.
|
27 |
+
def power_iteration(W, u_, update=True, eps=1e-12):
|
28 |
+
# Lists holding singular vectors and values
|
29 |
+
us, vs, svs = [], [], []
|
30 |
+
for i, u in enumerate(u_):
|
31 |
+
# Run one step of the power iteration
|
32 |
+
with torch.no_grad():
|
33 |
+
v = torch.matmul(u, W)
|
34 |
+
# Run Gram-Schmidt to subtract components of all other singular vectors
|
35 |
+
v = F.normalize(gram_schmidt(v, vs), eps=eps)
|
36 |
+
# Add to the list
|
37 |
+
vs += [v]
|
38 |
+
# Update the other singular vector
|
39 |
+
u = torch.matmul(v, W.t())
|
40 |
+
# Run Gram-Schmidt to subtract components of all other singular vectors
|
41 |
+
u = F.normalize(gram_schmidt(u, us), eps=eps)
|
42 |
+
# Add to the list
|
43 |
+
us += [u]
|
44 |
+
if update:
|
45 |
+
u_[i][:] = u
|
46 |
+
# Compute this singular value and add it to the list
|
47 |
+
svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
|
48 |
+
# svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
|
49 |
+
return svs, us, vs
|
50 |
+
|
51 |
+
|
52 |
+
# Convenience passthrough function
|
53 |
+
class identity(nn.Module):
|
54 |
+
def forward(self, input):
|
55 |
+
return input
|
56 |
+
|
57 |
+
|
58 |
+
# Spectral normalization base class
|
59 |
+
class SN(object):
|
60 |
+
def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
|
61 |
+
# Number of power iterations per step
|
62 |
+
self.num_itrs = num_itrs
|
63 |
+
# Number of singular values
|
64 |
+
self.num_svs = num_svs
|
65 |
+
# Transposed?
|
66 |
+
self.transpose = transpose
|
67 |
+
# Epsilon value for avoiding divide-by-0
|
68 |
+
self.eps = eps
|
69 |
+
# Register a singular vector for each sv
|
70 |
+
for i in range(self.num_svs):
|
71 |
+
self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
|
72 |
+
self.register_buffer('sv%d' % i, torch.ones(1))
|
73 |
+
|
74 |
+
# Singular vectors (u side)
|
75 |
+
@property
|
76 |
+
def u(self):
|
77 |
+
return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]
|
78 |
+
|
79 |
+
# Singular values;
|
80 |
+
# note that these buffers are just for logging and are not used in training.
|
81 |
+
@property
|
82 |
+
def sv(self):
|
83 |
+
return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]
|
84 |
+
|
85 |
+
# Compute the spectrally-normalized weight
|
86 |
+
def W_(self):
|
87 |
+
W_mat = self.weight.view(self.weight.size(0), -1)
|
88 |
+
if self.transpose:
|
89 |
+
W_mat = W_mat.t()
|
90 |
+
# Apply num_itrs power iterations
|
91 |
+
for _ in range(self.num_itrs):
|
92 |
+
svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
|
93 |
+
# Update the svs
|
94 |
+
if self.training:
|
95 |
+
with torch.no_grad(): # Make sure to do this in a no_grad() context or you'll get memory leaks!
|
96 |
+
for i, sv in enumerate(svs):
|
97 |
+
self.sv[i][:] = sv
|
98 |
+
return self.weight / svs[0]
|
99 |
+
|
100 |
+
|
101 |
+
# 2D Conv layer with spectral norm
|
102 |
+
class SNConv2d(nn.Conv2d, SN):
|
103 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
104 |
+
padding=0, dilation=1, groups=1, bias=True,
|
105 |
+
num_svs=1, num_itrs=1, eps=1e-12):
|
106 |
+
nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride,
|
107 |
+
padding, dilation, groups, bias)
|
108 |
+
SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)
|
109 |
+
|
110 |
+
def forward(self, x):
|
111 |
+
return F.conv2d(x, self.W_(), self.bias, self.stride,
|
112 |
+
self.padding, self.dilation, self.groups)
|
113 |
+
|
114 |
+
|
115 |
+
# Linear layer with spectral norm
|
116 |
+
class SNLinear(nn.Linear, SN):
|
117 |
+
def __init__(self, in_features, out_features, bias=True,
|
118 |
+
num_svs=1, num_itrs=1, eps=1e-12):
|
119 |
+
nn.Linear.__init__(self, in_features, out_features, bias)
|
120 |
+
SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)
|
121 |
+
|
122 |
+
def forward(self, x):
|
123 |
+
return F.linear(x, self.W_(), self.bias)
|
124 |
+
|
125 |
+
|
126 |
+
# Embedding layer with spectral norm
|
127 |
+
# We use num_embeddings as the dim instead of embedding_dim here
|
128 |
+
# for convenience sake
|
129 |
+
class SNEmbedding(nn.Embedding, SN):
|
130 |
+
def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
|
131 |
+
max_norm=None, norm_type=2, scale_grad_by_freq=False,
|
132 |
+
sparse=False, _weight=None,
|
133 |
+
num_svs=1, num_itrs=1, eps=1e-12):
|
134 |
+
nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx,
|
135 |
+
max_norm, norm_type, scale_grad_by_freq,
|
136 |
+
sparse, _weight)
|
137 |
+
SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)
|
138 |
+
|
139 |
+
def forward(self, x):
|
140 |
+
return F.embedding(x, self.W_())
|
141 |
+
|
142 |
+
|
143 |
+
# A non-local block as used in SA-GAN
|
144 |
+
# Note that the implementation as described in the paper is largely incorrect;
|
145 |
+
# refer to the released code for the actual implementation.
|
146 |
+
class Attention(nn.Module):
|
147 |
+
def __init__(self, ch, which_conv=SNConv2d, name='attention'):
|
148 |
+
super(Attention, self).__init__()
|
149 |
+
# Channel multiplier
|
150 |
+
self.ch = ch
|
151 |
+
self.which_conv = which_conv
|
152 |
+
self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
|
153 |
+
self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
|
154 |
+
self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False)
|
155 |
+
self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False)
|
156 |
+
# Learnable gain parameter
|
157 |
+
self.gamma = P(torch.tensor(0.), requires_grad=True)
|
158 |
+
|
159 |
+
def forward(self, x, y=None):
|
160 |
+
# Apply convs
|
161 |
+
theta = self.theta(x)
|
162 |
+
phi = F.max_pool2d(self.phi(x), [2, 2])
|
163 |
+
g = F.max_pool2d(self.g(x), [2, 2])
|
164 |
+
# Perform reshapes
|
165 |
+
theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
|
166 |
+
try:
|
167 |
+
phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
|
168 |
+
except:
|
169 |
+
print(phi.shape)
|
170 |
+
g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)
|
171 |
+
# Matmul and softmax to get attention maps
|
172 |
+
beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
|
173 |
+
# Attention map times g path
|
174 |
+
o = self.o(torch.bmm(g, beta.transpose(1, 2)).view(-1, self.ch // 2, x.shape[2], x.shape[3]))
|
175 |
+
return self.gamma * o + x
|
176 |
+
|
177 |
+
|
178 |
+
# Fused batchnorm op
|
179 |
+
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
|
180 |
+
# Apply scale and shift--if gain and bias are provided, fuse them here
|
181 |
+
# Prepare scale
|
182 |
+
scale = torch.rsqrt(var + eps)
|
183 |
+
# If a gain is provided, use it
|
184 |
+
if gain is not None:
|
185 |
+
scale = scale * gain
|
186 |
+
# Prepare shift
|
187 |
+
shift = mean * scale
|
188 |
+
# If bias is provided, use it
|
189 |
+
if bias is not None:
|
190 |
+
shift = shift - bias
|
191 |
+
return x * scale - shift
|
192 |
+
# return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way.
|
193 |
+
|
194 |
+
|
195 |
+
# Manual BN
|
196 |
+
# Calculate means and variances using mean-of-squares minus mean-squared
|
197 |
+
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
|
198 |
+
# Cast x to float32 if necessary
|
199 |
+
float_x = x.float()
|
200 |
+
# Calculate expected value of x (m) and expected value of x**2 (m2)
|
201 |
+
# Mean of x
|
202 |
+
m = torch.mean(float_x, [0, 2, 3], keepdim=True)
|
203 |
+
# Mean of x squared
|
204 |
+
m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True)
|
205 |
+
# Calculate variance as mean of squared minus mean squared.
|
206 |
+
var = (m2 - m ** 2)
|
207 |
+
# Cast back to float 16 if necessary
|
208 |
+
var = var.type(x.type())
|
209 |
+
m = m.type(x.type())
|
210 |
+
# Return mean and variance for updating stored mean/var if requested
|
211 |
+
if return_mean_var:
|
212 |
+
return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze()
|
213 |
+
else:
|
214 |
+
return fused_bn(x, m, var, gain, bias, eps)
|
215 |
+
|
216 |
+
|
217 |
+
# My batchnorm, supports standing stats
|
218 |
+
class myBN(nn.Module):
|
219 |
+
def __init__(self, num_channels, eps=1e-5, momentum=0.1):
|
220 |
+
super(myBN, self).__init__()
|
221 |
+
# momentum for updating running stats
|
222 |
+
self.momentum = momentum
|
223 |
+
# epsilon to avoid dividing by 0
|
224 |
+
self.eps = eps
|
225 |
+
# Momentum
|
226 |
+
self.momentum = momentum
|
227 |
+
# Register buffers
|
228 |
+
self.register_buffer('stored_mean', torch.zeros(num_channels))
|
229 |
+
self.register_buffer('stored_var', torch.ones(num_channels))
|
230 |
+
self.register_buffer('accumulation_counter', torch.zeros(1))
|
231 |
+
# Accumulate running means and vars
|
232 |
+
self.accumulate_standing = False
|
233 |
+
|
234 |
+
# reset standing stats
|
235 |
+
def reset_stats(self):
|
236 |
+
self.stored_mean[:] = 0
|
237 |
+
self.stored_var[:] = 0
|
238 |
+
self.accumulation_counter[:] = 0
|
239 |
+
|
240 |
+
def forward(self, x, gain, bias):
|
241 |
+
if self.training:
|
242 |
+
out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
|
243 |
+
# If accumulating standing stats, increment them
|
244 |
+
if self.accumulate_standing:
|
245 |
+
self.stored_mean[:] = self.stored_mean + mean.data
|
246 |
+
self.stored_var[:] = self.stored_var + var.data
|
247 |
+
self.accumulation_counter += 1.0
|
248 |
+
# If not accumulating standing stats, take running averages
|
249 |
+
else:
|
250 |
+
self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum
|
251 |
+
self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum
|
252 |
+
return out
|
253 |
+
# If not in training mode, use the stored statistics
|
254 |
+
else:
|
255 |
+
mean = self.stored_mean.view(1, -1, 1, 1)
|
256 |
+
var = self.stored_var.view(1, -1, 1, 1)
|
257 |
+
# If using standing stats, divide them by the accumulation counter
|
258 |
+
if self.accumulate_standing:
|
259 |
+
mean = mean / self.accumulation_counter
|
260 |
+
var = var / self.accumulation_counter
|
261 |
+
return fused_bn(x, mean, var, gain, bias, self.eps)
|
262 |
+
|
263 |
+
|
264 |
+
# Simple function to handle groupnorm norm stylization
|
265 |
+
def groupnorm(x, norm_style):
|
266 |
+
# If number of channels specified in norm_style:
|
267 |
+
if 'ch' in norm_style:
|
268 |
+
ch = int(norm_style.split('_')[-1])
|
269 |
+
groups = max(int(x.shape[1]) // ch, 1)
|
270 |
+
# If number of groups specified in norm style
|
271 |
+
elif 'grp' in norm_style:
|
272 |
+
groups = int(norm_style.split('_')[-1])
|
273 |
+
# If neither, default to groups = 16
|
274 |
+
else:
|
275 |
+
groups = 16
|
276 |
+
return F.group_norm(x, groups)
|
277 |
+
|
278 |
+
|
279 |
+
# Class-conditional bn
|
280 |
+
# output size is the number of channels, input size is for the linear layers
|
281 |
+
# Andy's Note: this class feels messy but I'm not really sure how to clean it up
|
282 |
+
# Suggestions welcome! (By which I mean, refactor this and make a pull request
|
283 |
+
# if you want to make this more readable/usable).
|
284 |
+
class ccbn(nn.Module):
|
285 |
+
def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
|
286 |
+
cross_replica=False, mybn=False, norm_style='bn', ):
|
287 |
+
super(ccbn, self).__init__()
|
288 |
+
self.output_size, self.input_size = output_size, input_size
|
289 |
+
# Prepare gain and bias layers
|
290 |
+
self.gain = which_linear(input_size, output_size)
|
291 |
+
self.bias = which_linear(input_size, output_size)
|
292 |
+
# epsilon to avoid dividing by 0
|
293 |
+
self.eps = eps
|
294 |
+
# Momentum
|
295 |
+
self.momentum = momentum
|
296 |
+
# Use cross-replica batchnorm?
|
297 |
+
self.cross_replica = cross_replica
|
298 |
+
# Use my batchnorm?
|
299 |
+
self.mybn = mybn
|
300 |
+
# Norm style?
|
301 |
+
self.norm_style = norm_style
|
302 |
+
|
303 |
+
if self.cross_replica:
|
304 |
+
self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
|
305 |
+
elif self.mybn:
|
306 |
+
self.bn = myBN(output_size, self.eps, self.momentum)
|
307 |
+
elif self.norm_style in ['bn', 'in']:
|
308 |
+
self.register_buffer('stored_mean', torch.zeros(output_size))
|
309 |
+
self.register_buffer('stored_var', torch.ones(output_size))
|
310 |
+
|
311 |
+
def forward(self, x, y):
|
312 |
+
# Calculate class-conditional gains and biases
|
313 |
+
gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
|
314 |
+
bias = self.bias(y).view(y.size(0), -1, 1, 1)
|
315 |
+
# If using my batchnorm
|
316 |
+
if self.mybn or self.cross_replica:
|
317 |
+
return self.bn(x, gain=gain, bias=bias)
|
318 |
+
# else:
|
319 |
+
else:
|
320 |
+
if self.norm_style == 'bn':
|
321 |
+
out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
|
322 |
+
self.training, 0.1, self.eps)
|
323 |
+
elif self.norm_style == 'in':
|
324 |
+
out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
|
325 |
+
self.training, 0.1, self.eps)
|
326 |
+
elif self.norm_style == 'gn':
|
327 |
+
out = groupnorm(x, self.norm_style)
|
328 |
+
elif self.norm_style == 'nonorm':
|
329 |
+
out = x
|
330 |
+
return out * gain + bias
|
331 |
+
|
332 |
+
def extra_repr(self):
|
333 |
+
s = 'out: {output_size}, in: {input_size},'
|
334 |
+
s += ' cross_replica={cross_replica}'
|
335 |
+
return s.format(**self.__dict__)
|
336 |
+
|
337 |
+
|
338 |
+
# Normal, non-class-conditional BN
|
339 |
+
class bn(nn.Module):
|
340 |
+
def __init__(self, output_size, eps=1e-5, momentum=0.1,
|
341 |
+
cross_replica=False, mybn=False):
|
342 |
+
super(bn, self).__init__()
|
343 |
+
self.output_size = output_size
|
344 |
+
# Prepare gain and bias layers
|
345 |
+
self.gain = P(torch.ones(output_size), requires_grad=True)
|
346 |
+
self.bias = P(torch.zeros(output_size), requires_grad=True)
|
347 |
+
# epsilon to avoid dividing by 0
|
348 |
+
self.eps = eps
|
349 |
+
# Momentum
|
350 |
+
self.momentum = momentum
|
351 |
+
# Use cross-replica batchnorm?
|
352 |
+
self.cross_replica = cross_replica
|
353 |
+
# Use my batchnorm?
|
354 |
+
self.mybn = mybn
|
355 |
+
|
356 |
+
if self.cross_replica:
|
357 |
+
self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
|
358 |
+
elif mybn:
|
359 |
+
self.bn = myBN(output_size, self.eps, self.momentum)
|
360 |
+
# Register buffers if neither of the above
|
361 |
+
else:
|
362 |
+
self.register_buffer('stored_mean', torch.zeros(output_size))
|
363 |
+
self.register_buffer('stored_var', torch.ones(output_size))
|
364 |
+
|
365 |
+
def forward(self, x, y=None):
|
366 |
+
if self.cross_replica or self.mybn:
|
367 |
+
gain = self.gain.view(1, -1, 1, 1)
|
368 |
+
bias = self.bias.view(1, -1, 1, 1)
|
369 |
+
return self.bn(x, gain=gain, bias=bias)
|
370 |
+
else:
|
371 |
+
return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
|
372 |
+
self.bias, self.training, self.momentum, self.eps)
|
373 |
+
|
374 |
+
|
375 |
+
# Generator blocks
|
376 |
+
# Note that this class assumes the kernel size and padding (and any other
|
377 |
+
# settings) have been selected in the main generator module and passed in
|
378 |
+
# through the which_conv arg. Similar rules apply with which_bn (the input
|
379 |
+
# size [which is actually the number of channels of the conditional info] must
|
380 |
+
# be preselected)
|
381 |
+
class GBlock(nn.Module):
|
382 |
+
def __init__(self, in_channels, out_channels,
|
383 |
+
which_conv1=nn.Conv2d, which_conv2=nn.Conv2d, which_bn=bn, activation=None,
|
384 |
+
upsample=None):
|
385 |
+
super(GBlock, self).__init__()
|
386 |
+
|
387 |
+
self.in_channels, self.out_channels = in_channels, out_channels
|
388 |
+
self.which_conv1, self.which_conv2, self.which_bn = which_conv1, which_conv2, which_bn
|
389 |
+
self.activation = activation
|
390 |
+
self.upsample = upsample
|
391 |
+
# Conv layers
|
392 |
+
self.conv1 = self.which_conv1(self.in_channels, self.out_channels)
|
393 |
+
self.conv2 = self.which_conv2(self.out_channels, self.out_channels)
|
394 |
+
self.learnable_sc = in_channels != out_channels or upsample
|
395 |
+
if self.learnable_sc:
|
396 |
+
self.conv_sc = self.which_conv1(in_channels, out_channels,
|
397 |
+
kernel_size=1, padding=0)
|
398 |
+
# Batchnorm layers
|
399 |
+
self.bn1 = self.which_bn(in_channels)
|
400 |
+
self.bn2 = self.which_bn(out_channels)
|
401 |
+
# upsample layers
|
402 |
+
self.upsample = upsample
|
403 |
+
|
404 |
+
def forward(self, x, y):
|
405 |
+
h = self.activation(self.bn1(x, y))
|
406 |
+
# h = self.activation(x)
|
407 |
+
# h=x
|
408 |
+
if self.upsample:
|
409 |
+
h = self.upsample(h)
|
410 |
+
x = self.upsample(x)
|
411 |
+
h = self.conv1(h)
|
412 |
+
h = self.activation(self.bn2(h, y))
|
413 |
+
# h = self.activation(h)
|
414 |
+
h = self.conv2(h)
|
415 |
+
if self.learnable_sc:
|
416 |
+
x = self.conv_sc(x)
|
417 |
+
return h + x
|
418 |
+
|
419 |
+
|
420 |
+
# Residual block for the discriminator
|
421 |
+
class DBlock(nn.Module):
|
422 |
+
def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True,
|
423 |
+
preactivation=False, activation=None, downsample=None, ):
|
424 |
+
super(DBlock, self).__init__()
|
425 |
+
self.in_channels, self.out_channels = in_channels, out_channels
|
426 |
+
# If using wide D (as in SA-GAN and BigGAN), change the channel pattern
|
427 |
+
self.hidden_channels = self.out_channels if wide else self.in_channels
|
428 |
+
self.which_conv = which_conv
|
429 |
+
self.preactivation = preactivation
|
430 |
+
self.activation = activation
|
431 |
+
self.downsample = downsample
|
432 |
+
|
433 |
+
# Conv layers
|
434 |
+
self.conv1 = self.which_conv(self.in_channels, self.hidden_channels)
|
435 |
+
self.conv2 = self.which_conv(self.hidden_channels, self.out_channels)
|
436 |
+
self.learnable_sc = True if (in_channels != out_channels) or downsample else False
|
437 |
+
if self.learnable_sc:
|
438 |
+
self.conv_sc = self.which_conv(in_channels, out_channels,
|
439 |
+
kernel_size=1, padding=0)
|
440 |
+
|
441 |
+
def shortcut(self, x):
|
442 |
+
if self.preactivation:
|
443 |
+
if self.learnable_sc:
|
444 |
+
x = self.conv_sc(x)
|
445 |
+
if self.downsample:
|
446 |
+
x = self.downsample(x)
|
447 |
+
else:
|
448 |
+
if self.downsample:
|
449 |
+
x = self.downsample(x)
|
450 |
+
if self.learnable_sc:
|
451 |
+
x = self.conv_sc(x)
|
452 |
+
return x
|
453 |
+
|
454 |
+
def forward(self, x):
|
455 |
+
if self.preactivation:
|
456 |
+
# h = self.activation(x) # NOT TODAY SATAN
|
457 |
+
# Andy's note: This line *must* be an out-of-place ReLU or it
|
458 |
+
# will negatively affect the shortcut connection.
|
459 |
+
h = F.relu(x)
|
460 |
+
else:
|
461 |
+
h = x
|
462 |
+
h = self.conv1(h)
|
463 |
+
h = self.conv2(self.activation(h))
|
464 |
+
if self.downsample:
|
465 |
+
h = self.downsample(h)
|
466 |
+
|
467 |
+
return h + self.shortcut(x)
|
468 |
+
|
469 |
+
# dogball
|
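A minimal sketch of how the spectral-norm machinery above behaves, assuming the package imports (e.g. `models.sync_batchnorm`) resolve: after a few forward passes in training mode, the logged power-iteration estimate approaches the exact top singular value of the weight.

```python
# Minimal sketch: check that SNLinear's power iteration tracks the top singular value.
import torch
from models import BigGAN_layers as layers

layer = layers.SNLinear(64, 32, num_svs=1, num_itrs=1)
layer.train()                      # u/sv buffers are only updated in training mode
for _ in range(50):                # each forward runs one power-iteration step
    _ = layer(torch.randn(4, 64))

estimated = layer.sv[0].item()                                 # logged singular value
exact = torch.linalg.svdvals(layer.weight.detach())[0].item()  # reference value
print(f"power iteration: {estimated:.4f}  exact: {exact:.4f}")
```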
models/BigGAN_networks.py
ADDED
@@ -0,0 +1,379 @@
1 |
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
2 |
+
# SPDX-License-Identifier: MIT
|
3 |
+
import functools
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import random
|
11 |
+
|
12 |
+
from util.augmentations import ProgressiveWordCrop, CycleWordCrop, StaticWordCrop, RandomWordCrop
|
13 |
+
from . import BigGAN_layers as layers
|
14 |
+
from .networks import init_weights
|
15 |
+
import torchvision
|
16 |
+
# Attention is passed in in the format '32_64' to mean applying an attention
|
17 |
+
# block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64.
|
18 |
+
|
19 |
+
from models.blocks import Conv2dBlock, ResBlocks
|
20 |
+
|
21 |
+
|
22 |
+
# Discriminator architecture, same paradigm as G's above
|
23 |
+
def D_arch(ch=64, attention='64', input_nc=3, ksize='333333', dilation='111111'):
|
24 |
+
arch = {}
|
25 |
+
arch[256] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 8, 16]],
|
26 |
+
'out_channels': [item * ch for item in [1, 2, 4, 8, 8, 16, 16]],
|
27 |
+
'downsample': [True] * 6 + [False],
|
28 |
+
'resolution': [128, 64, 32, 16, 8, 4, 4],
|
29 |
+
'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
|
30 |
+
for i in range(2, 8)}}
|
31 |
+
arch[128] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 16]],
|
32 |
+
'out_channels': [item * ch for item in [1, 2, 4, 8, 16, 16]],
|
33 |
+
'downsample': [True] * 5 + [False],
|
34 |
+
'resolution': [64, 32, 16, 8, 4, 4],
|
35 |
+
'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
|
36 |
+
for i in range(2, 8)}}
|
37 |
+
arch[64] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8]],
|
38 |
+
'out_channels': [item * ch for item in [1, 2, 4, 8, 16]],
|
39 |
+
'downsample': [True] * 4 + [False],
|
40 |
+
'resolution': [32, 16, 8, 4, 4],
|
41 |
+
'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
|
42 |
+
for i in range(2, 7)}}
|
43 |
+
arch[63] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8]],
|
44 |
+
'out_channels': [item * ch for item in [1, 2, 4, 8, 16]],
|
45 |
+
'downsample': [True] * 4 + [False],
|
46 |
+
'resolution': [32, 16, 8, 4, 4],
|
47 |
+
'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
|
48 |
+
for i in range(2, 7)}}
|
49 |
+
arch[32] = {'in_channels': [input_nc] + [item * ch for item in [4, 4, 4]],
|
50 |
+
'out_channels': [item * ch for item in [4, 4, 4, 4]],
|
51 |
+
'downsample': [True, True, False, False],
|
52 |
+
'resolution': [16, 16, 16, 16],
|
53 |
+
'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
|
54 |
+
for i in range(2, 6)}}
|
55 |
+
arch[129] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 8, 16]],
|
56 |
+
'out_channels': [item * ch for item in [1, 2, 4, 8, 8, 16, 16]],
|
57 |
+
'downsample': [True] * 6 + [False],
|
58 |
+
'resolution': [128, 64, 32, 16, 8, 4, 4],
|
59 |
+
'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
|
60 |
+
for i in range(2, 8)}}
|
61 |
+
arch[33] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 16]],
|
62 |
+
'out_channels': [item * ch for item in [1, 2, 4, 8, 16, 16]],
|
63 |
+
'downsample': [True] * 5 + [False],
|
64 |
+
'resolution': [64, 32, 16, 8, 4, 4],
|
65 |
+
'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
|
66 |
+
for i in range(2, 10)}}
|
67 |
+
arch[31] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 16]],
|
68 |
+
'out_channels': [item * ch for item in [1, 2, 4, 8, 16, 16]],
|
69 |
+
'downsample': [True] * 5 + [False],
|
70 |
+
'resolution': [64, 32, 16, 8, 4, 4],
|
71 |
+
'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
|
72 |
+
for i in range(2, 10)}}
|
73 |
+
arch[16] = {'in_channels': [input_nc] + [ch * item for item in [1, 8, 16]],
|
74 |
+
'out_channels': [item * ch for item in [1, 8, 16, 16]],
|
75 |
+
'downsample': [True] * 3 + [False],
|
76 |
+
'resolution': [16, 8, 4, 4],
|
77 |
+
'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
|
78 |
+
for i in range(2, 5)}}
|
79 |
+
|
80 |
+
arch[17] = {'in_channels': [input_nc] + [ch * item for item in [1, 4]],
|
81 |
+
'out_channels': [item * ch for item in [1, 4, 8]],
|
82 |
+
'downsample': [True] * 3,
|
83 |
+
'resolution': [16, 8, 4],
|
84 |
+
'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
|
85 |
+
for i in range(2, 5)}}
|
86 |
+
|
87 |
+
|
88 |
+
arch[20] = {'in_channels': [input_nc] + [ch * item for item in [1, 8, 16]],
|
89 |
+
'out_channels': [item * ch for item in [1, 8, 16, 16]],
|
90 |
+
'downsample': [True] * 3 + [False],
|
91 |
+
'resolution': [16, 8, 4, 4],
|
92 |
+
'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
|
93 |
+
for i in range(2, 5)}}
|
94 |
+
return arch
|
95 |
+
|
96 |
+
|
97 |
+
class Discriminator(nn.Module):
|
98 |
+
|
99 |
+
def __init__(self, resolution, D_ch=64, D_wide=True, D_kernel_size=3, D_attn='64',
|
100 |
+
num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
|
101 |
+
SN_eps=1e-8, output_dim=1, D_mixed_precision=False, D_fp16=False,
|
102 |
+
D_init='N02', skip_init=False, D_param='SN', gpu_ids=[0],bn_linear='SN', input_nc=1, one_hot=False, crop_size: list = None, **kwargs):
|
103 |
+
|
104 |
+
super(Discriminator, self).__init__()
|
105 |
+
self.crop = crop_size is not None and len(crop_size) > 0
|
106 |
+
|
107 |
+
use_padding = False
|
108 |
+
|
109 |
+
if self.crop:
|
110 |
+
w_crop = StaticWordCrop(crop_size[0], use_padding=use_padding) if len(crop_size) == 1 else RandomWordCrop(crop_size[0], crop_size[1], use_padding=use_padding)
|
111 |
+
|
112 |
+
self.augmenter = w_crop
|
113 |
+
|
114 |
+
self.name = 'D'
|
115 |
+
# gpu_ids
|
116 |
+
self.gpu_ids = gpu_ids
|
117 |
+
# one_hot representation
|
118 |
+
self.one_hot = one_hot
|
119 |
+
# Width multiplier
|
120 |
+
self.ch = D_ch
|
121 |
+
# Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
|
122 |
+
self.D_wide = D_wide
|
123 |
+
# Resolution
|
124 |
+
self.resolution = resolution
|
125 |
+
# Kernel size
|
126 |
+
self.kernel_size = D_kernel_size
|
127 |
+
# Attention?
|
128 |
+
self.attention = D_attn
|
129 |
+
# Activation
|
130 |
+
self.activation = D_activation
|
131 |
+
# Initialization style
|
132 |
+
self.init = D_init
|
133 |
+
# Parameterization style
|
134 |
+
self.D_param = D_param
|
135 |
+
# Epsilon for Spectral Norm?
|
136 |
+
self.SN_eps = SN_eps
|
137 |
+
# Fp16?
|
138 |
+
self.fp16 = D_fp16
|
139 |
+
# Architecture
|
140 |
+
self.arch = D_arch(self.ch, self.attention, input_nc)[resolution]
|
141 |
+
|
142 |
+
# Which convs, batchnorms, and linear layers to use
|
143 |
+
# No option to turn off SN in D right now
|
144 |
+
if self.D_param == 'SN':
|
145 |
+
self.which_conv = functools.partial(layers.SNConv2d,
|
146 |
+
kernel_size=3, padding=1,
|
147 |
+
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
|
148 |
+
eps=self.SN_eps)
|
149 |
+
self.which_linear = functools.partial(layers.SNLinear,
|
150 |
+
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
|
151 |
+
eps=self.SN_eps)
|
152 |
+
self.which_embedding = functools.partial(layers.SNEmbedding,
|
153 |
+
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
|
154 |
+
eps=self.SN_eps)
|
155 |
+
if bn_linear=='SN':
|
156 |
+
self.which_embedding = functools.partial(layers.SNLinear,
|
157 |
+
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
|
158 |
+
eps=self.SN_eps)
|
159 |
+
else:
|
160 |
+
self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
|
161 |
+
self.which_linear = nn.Linear
|
162 |
+
# We use a non-spectral-normed embedding here regardless;
|
163 |
+
# For some reason applying SN to G's embedding seems to randomly cripple G
|
164 |
+
self.which_embedding = nn.Embedding
|
165 |
+
if one_hot:
|
166 |
+
self.which_embedding = functools.partial(layers.SNLinear,
|
167 |
+
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
|
168 |
+
eps=self.SN_eps)
|
169 |
+
# Prepare model
|
170 |
+
# self.blocks is a doubly-nested list of modules, the outer loop intended
|
171 |
+
# to be over blocks at a given resolution (resblocks and/or self-attention)
|
172 |
+
self.blocks = []
|
173 |
+
for index in range(len(self.arch['out_channels'])):
|
174 |
+
self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index],
|
175 |
+
out_channels=self.arch['out_channels'][index],
|
176 |
+
which_conv=self.which_conv,
|
177 |
+
wide=self.D_wide,
|
178 |
+
activation=self.activation,
|
179 |
+
preactivation=(index > 0),
|
180 |
+
downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]]
|
181 |
+
# If attention on this block, attach it to the end
|
182 |
+
if self.arch['attention'][self.arch['resolution'][index]]:
|
183 |
+
print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
|
184 |
+
self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
|
185 |
+
self.which_conv)]
|
186 |
+
# Turn self.blocks into a ModuleList so that it's all properly registered.
|
187 |
+
self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
|
188 |
+
# Linear output layer. The output dimension is typically 1, but may be
|
189 |
+
# larger if we're e.g. turning this into a VAE with an inference output
|
190 |
+
self.dropout = torch.nn.Dropout(p=0.5)
|
191 |
+
self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
|
192 |
+
|
193 |
+
# Initialize weights
|
194 |
+
if not skip_init:
|
195 |
+
self = init_weights(self, D_init)
|
196 |
+
|
197 |
+
def update_parameters(self, epoch: int):
|
198 |
+
if self.crop:
|
199 |
+
self.augmenter.update(epoch)
|
200 |
+
|
201 |
+
def forward(self, x, y=None, **kwargs):
|
202 |
+
# Stick x into h for cleaner for loops without flow control
|
203 |
+
if self.crop and random.uniform(0.0, 1.0) < 0.33:
|
204 |
+
x = self.augmenter(x)
|
205 |
+
|
206 |
+
#imgs = [np.squeeze((img.detach().cpu().numpy() + 1.0) / 2.0) for img in x]
|
207 |
+
#imgs = (np.vstack(imgs) * 255.0).astype(np.uint8)
|
208 |
+
#cv2.imwrite(f"saved_images/debug/{random.randint(0, 1000)}.jpg", imgs)
|
209 |
+
|
210 |
+
h = x
|
211 |
+
# Loop over blocks
|
212 |
+
for index, blocklist in enumerate(self.blocks):
|
213 |
+
for block in blocklist:
|
214 |
+
h = block(h)
|
215 |
+
|
216 |
+
# Apply global sum pooling as in SN-GAN
|
217 |
+
h = torch.sum(self.activation(h), [2, 3])
|
218 |
+
out = self.linear(h)
|
219 |
+
|
220 |
+
return out
|
221 |
+
|
222 |
+
def return_features(self, x, y=None):
|
223 |
+
# Stick x into h for cleaner for loops without flow control
|
224 |
+
h = x
|
225 |
+
block_output = []
|
226 |
+
# Loop over blocks
|
227 |
+
for index, blocklist in enumerate(self.blocks):
|
228 |
+
for block in blocklist:
|
229 |
+
h = block(h)
|
230 |
+
block_output.append(h)
|
231 |
+
# Apply global sum pooling as in SN-GAN
|
232 |
+
# h = torch.sum(self.activation(h), [2, 3])
|
233 |
+
return block_output
|
234 |
+
|
235 |
+
|
236 |
+
class WDiscriminator(nn.Module):
|
237 |
+
|
238 |
+
def __init__(self, resolution, n_classes, output_dim, D_ch=64, D_wide=True, D_kernel_size=3, D_attn='64',
|
239 |
+
num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
|
240 |
+
SN_eps=1e-8, D_mixed_precision=False, D_fp16=False,
|
241 |
+
D_init='N02', skip_init=False, D_param='SN', gpu_ids=[0],bn_linear='SN', input_nc=1, one_hot=False):
|
242 |
+
super(WDiscriminator, self).__init__()
|
243 |
+
|
244 |
+
self.name = 'D'
|
245 |
+
# gpu_ids
|
246 |
+
self.gpu_ids = gpu_ids
|
247 |
+
# one_hot representation
|
248 |
+
self.one_hot = one_hot
|
249 |
+
# Width multiplier
|
250 |
+
self.ch = D_ch
|
251 |
+
# Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
|
252 |
+
self.D_wide = D_wide
|
253 |
+
# Resolution
|
254 |
+
self.resolution = resolution
|
255 |
+
# Kernel size
|
256 |
+
self.kernel_size = D_kernel_size
|
257 |
+
# Attention?
|
258 |
+
self.attention = D_attn
|
259 |
+
# Number of classes
|
260 |
+
self.n_classes = n_classes
|
261 |
+
# Activation
|
262 |
+
self.activation = D_activation
|
263 |
+
# Initialization style
|
264 |
+
self.init = D_init
|
265 |
+
# Parameterization style
|
266 |
+
self.D_param = D_param
|
267 |
+
# Epsilon for Spectral Norm?
|
268 |
+
self.SN_eps = SN_eps
|
269 |
+
# Fp16?
|
270 |
+
self.fp16 = D_fp16
|
271 |
+
# Architecture
|
272 |
+
self.arch = D_arch(self.ch, self.attention, input_nc)[resolution]
|
273 |
+
|
274 |
+
# Which convs, batchnorms, and linear layers to use
|
275 |
+
# No option to turn off SN in D right now
|
276 |
+
if self.D_param == 'SN':
|
277 |
+
self.which_conv = functools.partial(layers.SNConv2d,
|
278 |
+
kernel_size=3, padding=1,
|
279 |
+
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
|
280 |
+
eps=self.SN_eps)
|
281 |
+
self.which_linear = functools.partial(layers.SNLinear,
|
282 |
+
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
|
283 |
+
eps=self.SN_eps)
|
284 |
+
self.which_embedding = functools.partial(layers.SNEmbedding,
|
285 |
+
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
|
286 |
+
eps=self.SN_eps)
|
287 |
+
if bn_linear == 'SN':
|
288 |
+
self.which_embedding = functools.partial(layers.SNLinear,
|
289 |
+
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
|
290 |
+
eps=self.SN_eps)
|
291 |
+
else:
|
292 |
+
self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
|
293 |
+
self.which_linear = nn.Linear
|
294 |
+
# We use a non-spectral-normed embedding here regardless;
|
295 |
+
# For some reason applying SN to G's embedding seems to randomly cripple G
|
296 |
+
self.which_embedding = nn.Embedding
|
297 |
+
if one_hot:
|
298 |
+
self.which_embedding = functools.partial(layers.SNLinear,
|
299 |
+
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
|
300 |
+
eps=self.SN_eps)
|
301 |
+
# Prepare model
|
302 |
+
# self.blocks is a doubly-nested list of modules, the outer loop intended
|
303 |
+
# to be over blocks at a given resolution (resblocks and/or self-attention)
|
304 |
+
self.blocks = []
|
305 |
+
for index in range(len(self.arch['out_channels'])):
|
306 |
+
self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index],
|
307 |
+
out_channels=self.arch['out_channels'][index],
|
308 |
+
which_conv=self.which_conv,
|
309 |
+
wide=self.D_wide,
|
310 |
+
activation=self.activation,
|
311 |
+
preactivation=(index > 0),
|
312 |
+
downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]]
|
313 |
+
# If attention on this block, attach it to the end
|
314 |
+
if self.arch['attention'][self.arch['resolution'][index]]:
|
315 |
+
print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
|
316 |
+
self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
|
317 |
+
self.which_conv)]
|
318 |
+
# Turn self.blocks into a ModuleList so that it's all properly registered.
|
319 |
+
self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
|
320 |
+
# Linear output layer. The output dimension is typically 1, but may be
|
321 |
+
# larger if we're e.g. turning this into a VAE with an inference output
|
322 |
+
self.dropout = torch.nn.Dropout(p=0.5)
|
323 |
+
self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
|
324 |
+
# Embedding for projection discrimination
|
325 |
+
self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])
|
326 |
+
self.cross_entropy = nn.CrossEntropyLoss()
|
327 |
+
# Initialize weights
|
328 |
+
if not skip_init:
|
329 |
+
self = init_weights(self, D_init)
|
330 |
+
|
331 |
+
def update_parameters(self, epoch: int):
|
332 |
+
pass
|
333 |
+
|
334 |
+
def forward(self, x, y=None, **kwargs):
|
335 |
+
# Stick x into h for cleaner for loops without flow control
|
336 |
+
h = x
|
337 |
+
# Loop over blocks
|
338 |
+
for index, blocklist in enumerate(self.blocks):
|
339 |
+
for block in blocklist:
|
340 |
+
h = block(h)
|
341 |
+
# Apply global sum pooling as in SN-GAN
|
342 |
+
h = torch.sum(self.activation(h), [2, 3])
|
343 |
+
|
344 |
+
# Get initial class-unconditional output
|
345 |
+
out = self.linear(h)
|
346 |
+
# Get projection of final featureset onto class vectors and add to evidence
|
347 |
+
#if y is not None:
|
348 |
+
loss = self.cross_entropy(out, y.long())
|
349 |
+
return loss
|
350 |
+
|
351 |
+
def return_features(self, x, y=None):
|
352 |
+
# Stick x into h for cleaner for loops without flow control
|
353 |
+
h = x
|
354 |
+
block_output = []
|
355 |
+
# Loop over blocks
|
356 |
+
for index, blocklist in enumerate(self.blocks):
|
357 |
+
for block in blocklist:
|
358 |
+
h = block(h)
|
359 |
+
block_output.append(h)
|
360 |
+
# Apply global sum pooling as in SN-GAN
|
361 |
+
# h = torch.sum(self.activation(h), [2, 3])
|
362 |
+
return block_output
|
363 |
+
|
364 |
+
|
365 |
+
class Encoder(Discriminator):
|
366 |
+
def __init__(self, opt, output_dim, **kwargs):
|
367 |
+
super(Encoder, self).__init__(**vars(opt))
|
368 |
+
self.output_layer = nn.Sequential(self.activation,
|
369 |
+
nn.Conv2d(self.arch['out_channels'][-1], output_dim, kernel_size=(4,2), padding=0, stride=2))
|
370 |
+
|
371 |
+
def forward(self, x):
|
372 |
+
# Stick x into h for cleaner for loops without flow control
|
373 |
+
h = x
|
374 |
+
# Loop over blocks
|
375 |
+
for index, blocklist in enumerate(self.blocks):
|
376 |
+
for block in blocklist:
|
377 |
+
h = block(h)
|
378 |
+
out = self.output_layer(h)
|
379 |
+
return out
|
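A minimal sketch of inspecting the discriminator layout defined by `D_arch` above, assuming the module's own imports (e.g. `util.augmentations`) resolve; the resolution-16 entry is used purely as an example.

```python
# Minimal sketch: print the per-block channel layout of the resolution-16 discriminator.
from models.BigGAN_networks import D_arch

arch = D_arch(ch=64, attention='64', input_nc=1)[16]
for i, (cin, cout, down, res) in enumerate(zip(arch['in_channels'],
                                               arch['out_channels'],
                                               arch['downsample'],
                                               arch['resolution'])):
    print(f"block {i}: {cin:4d} -> {cout:4d}  downsample={down}  resolution={res}")
```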
models/OCR_network.py
ADDED
@@ -0,0 +1,193 @@
1 |
+
import torch
|
2 |
+
from .networks import *
|
3 |
+
|
4 |
+
|
5 |
+
class BidirectionalLSTM(nn.Module):
|
6 |
+
|
7 |
+
def __init__(self, nIn, nHidden, nOut):
|
8 |
+
super(BidirectionalLSTM, self).__init__()
|
9 |
+
|
10 |
+
self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
|
11 |
+
self.embedding = nn.Linear(nHidden * 2, nOut)
|
12 |
+
|
13 |
+
|
14 |
+
def forward(self, input):
|
15 |
+
recurrent, _ = self.rnn(input)
|
16 |
+
T, b, h = recurrent.size()
|
17 |
+
t_rec = recurrent.view(T * b, h)
|
18 |
+
|
19 |
+
output = self.embedding(t_rec) # [T * b, nOut]
|
20 |
+
output = output.view(T, b, -1)
|
21 |
+
|
22 |
+
return output
|
23 |
+
|
24 |
+
|
25 |
+
class CRNN(nn.Module):
|
26 |
+
|
27 |
+
def __init__(self, args, leakyRelu=False):
|
28 |
+
super(CRNN, self).__init__()
|
29 |
+
self.args = args
|
30 |
+
self.name = 'OCR'
|
31 |
+
self.add_noise = False
|
32 |
+
self.noise_fac = torch.distributions.Normal(loc=torch.tensor([0.]), scale=torch.tensor([0.2]))
|
33 |
+
#assert opt.imgH % 16 == 0, 'imgH has to be a multiple of 16'
|
34 |
+
|
35 |
+
ks = [3, 3, 3, 3, 3, 3, 2]
|
36 |
+
ps = [1, 1, 1, 1, 1, 1, 0]
|
37 |
+
ss = [1, 1, 1, 1, 1, 1, 1]
|
38 |
+
nm = [64, 128, 256, 256, 512, 512, 512]
|
39 |
+
|
40 |
+
cnn = nn.Sequential()
|
41 |
+
nh = 256
|
42 |
+
dealwith_lossnone=False # whether to replace all nan/inf in gradients to zero
|
43 |
+
|
44 |
+
def convRelu(i, batchNormalization=False):
|
45 |
+
nIn = 1 if i == 0 else nm[i - 1]
|
46 |
+
nOut = nm[i]
|
47 |
+
cnn.add_module('conv{0}'.format(i),
|
48 |
+
nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
|
49 |
+
if batchNormalization:
|
50 |
+
cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
|
51 |
+
if leakyRelu:
|
52 |
+
cnn.add_module('relu{0}'.format(i),
|
53 |
+
nn.LeakyReLU(0.2, inplace=True))
|
54 |
+
else:
|
55 |
+
cnn.add_module('relu{0}'.format(i), nn.ReLU(True))
|
56 |
+
|
57 |
+
convRelu(0)
|
58 |
+
cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2)) # 64x16x64
|
59 |
+
convRelu(1)
|
60 |
+
cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2)) # 128x8x32
|
61 |
+
convRelu(2, True)
|
62 |
+
convRelu(3)
|
63 |
+
cnn.add_module('pooling{0}'.format(2),
|
64 |
+
nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16
|
65 |
+
convRelu(4, True)
|
66 |
+
if self.args.resolution==63:
|
67 |
+
cnn.add_module('pooling{0}'.format(3),
|
68 |
+
nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16
|
69 |
+
convRelu(5)
|
70 |
+
cnn.add_module('pooling{0}'.format(4),
|
71 |
+
nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16
|
72 |
+
convRelu(6, True) # 512x1x16
|
73 |
+
|
74 |
+
self.cnn = cnn
|
75 |
+
self.use_rnn = False
|
76 |
+
if self.use_rnn:
|
77 |
+
self.rnn = nn.Sequential(
|
78 |
+
BidirectionalLSTM(512, nh, nh),
|
79 |
+
BidirectionalLSTM(nh, nh, self.args.vocab_size))
|
80 |
+
else:
|
81 |
+
self.linear = nn.Linear(512, self.args.vocab_size)
|
82 |
+
|
83 |
+
# replace all nan/inf in gradients to zero
|
84 |
+
if dealwith_lossnone:
|
85 |
+
self.register_backward_hook(self.backward_hook)
|
86 |
+
|
87 |
+
self.device = torch.device('cuda:{}'.format(0))
|
88 |
+
self.init = 'N02'
|
89 |
+
# Initialize weights
|
90 |
+
|
91 |
+
self = init_weights(self, self.init)
|
92 |
+
|
93 |
+
def forward(self, input):
|
94 |
+
# conv features
|
95 |
+
if self.add_noise:
|
96 |
+
input = input + self.noise_fac.sample(input.size()).squeeze(-1).to(self.args.device)
|
97 |
+
conv = self.cnn(input)
|
98 |
+
b, c, h, w = conv.size()
|
99 |
+
if h != 1:
|
100 |
+
print(f'unexpected conv feature height: {h}')
|
101 |
+
assert h == 1, "the height of conv must be 1"
|
102 |
+
conv = conv.squeeze(2)
|
103 |
+
conv = conv.permute(2, 0, 1) # [w, b, c]
|
104 |
+
|
105 |
+
if self.use_rnn:
|
106 |
+
# rnn features
|
107 |
+
output = self.rnn(conv)
|
108 |
+
else:
|
109 |
+
output = self.linear(conv)
|
110 |
+
return output
|
111 |
+
|
112 |
+
def backward_hook(self, module, grad_input, grad_output):
|
113 |
+
for g in grad_input:
|
114 |
+
g[g != g] = 0 # replace all nan/inf in gradients to zero
|
115 |
+
|
116 |
+
|
117 |
+
class strLabelConverter(object):
|
118 |
+
"""Convert between str and label.
|
119 |
+
NOTE:
|
120 |
+
Insert `blank` to the alphabet for CTC.
|
121 |
+
Args:
|
122 |
+
alphabet (str): set of the possible characters.
|
123 |
+
ignore_case (bool, default=True): whether or not to ignore all of the case.
|
124 |
+
"""
|
125 |
+
|
126 |
+
def __init__(self, alphabet, ignore_case=False):
|
127 |
+
self._ignore_case = ignore_case
|
128 |
+
if self._ignore_case:
|
129 |
+
alphabet = alphabet.lower()
|
130 |
+
self.alphabet = alphabet + '-' # for `-1` index
|
131 |
+
|
132 |
+
self.dict = {}
|
133 |
+
for i, char in enumerate(alphabet):
|
134 |
+
# NOTE: 0 is reserved for 'blank' required by wrap_ctc
|
135 |
+
self.dict[char] = i + 1
|
136 |
+
|
137 |
+
def encode(self, text):
|
138 |
+
"""Support batch or single str.
|
139 |
+
Args:
|
140 |
+
text (str or list of str): texts to convert.
|
141 |
+
Returns:
|
142 |
+
torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
|
143 |
+
torch.IntTensor [n]: length of each text.
|
144 |
+
"""
|
145 |
+
length = []
|
146 |
+
result = []
|
147 |
+
results = []
|
148 |
+
for item in text:
|
149 |
+
if isinstance(item, bytes): item = item.decode('utf-8', 'strict')
|
150 |
+
length.append(len(item))
|
151 |
+
for char in item:
|
152 |
+
index = self.dict[char]
|
153 |
+
result.append(index)
|
154 |
+
results.append(result)
|
155 |
+
result = []
|
156 |
+
|
157 |
+
return torch.nn.utils.rnn.pad_sequence([torch.LongTensor(text) for text in results], batch_first=True), torch.IntTensor(length), None
|
158 |
+
|
159 |
+
def decode(self, t, length, raw=False):
|
160 |
+
"""Decode encoded texts back into strs.
|
161 |
+
Args:
|
162 |
+
torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
|
163 |
+
torch.IntTensor [n]: length of each text.
|
164 |
+
Raises:
|
165 |
+
AssertionError: when the texts and its length does not match.
|
166 |
+
Returns:
|
167 |
+
text (str or list of str): texts to convert.
|
168 |
+
"""
|
169 |
+
if length.numel() == 1:
|
170 |
+
length = length[0]
|
171 |
+
assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(),
|
172 |
+
length)
|
173 |
+
if raw:
|
174 |
+
return ''.join([self.alphabet[i - 1] for i in t])
|
175 |
+
else:
|
176 |
+
char_list = []
|
177 |
+
for i in range(length):
|
178 |
+
if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
|
179 |
+
char_list.append(self.alphabet[t[i] - 1])
|
180 |
+
return ''.join(char_list)
|
181 |
+
else:
|
182 |
+
# batch mode
|
183 |
+
assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(
|
184 |
+
t.numel(), length.sum())
|
185 |
+
texts = []
|
186 |
+
index = 0
|
187 |
+
for i in range(length.numel()):
|
188 |
+
l = length[i]
|
189 |
+
texts.append(
|
190 |
+
self.decode(
|
191 |
+
t[index:index + l], torch.IntTensor([l]), raw=raw))
|
192 |
+
index += l
|
193 |
+
return texts
|
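A minimal sketch of the CTC label converter defined above; the lowercase alphabet is illustrative, the real alphabet comes from the training configuration.

```python
# Minimal sketch: encode/decode round trip with strLabelConverter.
import torch
from models.OCR_network import strLabelConverter

converter = strLabelConverter('abcdefghijklmnopqrstuvwxyz')
encoded, lengths, _ = converter.encode(['hello', 'world'])
print(encoded.shape, lengths)                     # torch.Size([2, 5]) tensor([5, 5], dtype=torch.int32)
print(converter.decode(encoded[0], lengths[:1]))  # 'helo': repeated symbols collapse
```

The collapsed double `l` is expected: `decode` applies the CTC rule that drops repeated non-blank symbols, so it is meant to be run on the per-frame argmax path of the recognizer output rather than on the dense labels produced by `encode`.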
models/__init__.py
ADDED
@@ -0,0 +1,65 @@
1 |
+
"""This package contains modules related to objective functions, optimizations, and network architectures.
|
2 |
+
|
3 |
+
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
|
4 |
+
You need to implement the following five functions:
|
5 |
+
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
|
6 |
+
-- <set_input>: unpack data from dataset and apply preprocessing.
|
7 |
+
-- <forward>: produce intermediate results.
|
8 |
+
-- <optimize_parameters>: calculate loss, gradients, and update network weights.
|
9 |
+
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
|
10 |
+
|
11 |
+
In the function <__init__>, you need to define four lists:
|
12 |
+
-- self.loss_names (str list): specify the training losses that you want to plot and save.
|
13 |
+
-- self.model_names (str list): define networks used in our training.
|
14 |
+
-- self.visual_names (str list): specify the images that you want to display and save.
|
15 |
+
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.
|
16 |
+
|
17 |
+
Now you can use the model class by specifying flag '--model dummy'.
|
18 |
+
"""
|
19 |
+
|
20 |
+
import importlib
|
21 |
+
|
22 |
+
|
23 |
+
def find_model_using_name(model_name):
|
24 |
+
"""Import the module "models/[model_name]_model.py".
|
25 |
+
|
26 |
+
In the file, the class called DatasetNameModel() will
|
27 |
+
be instantiated. It has to be a subclass of BaseModel,
|
28 |
+
and it is case-insensitive.
|
29 |
+
"""
|
30 |
+
model_filename = "models." + model_name + "_model"
|
31 |
+
modellib = importlib.import_module(model_filename)
|
32 |
+
model = None
|
33 |
+
target_model_name = model_name.replace('_', '') + 'model'
|
34 |
+
for name, cls in modellib.__dict__.items():
|
35 |
+
if name.lower() == target_model_name.lower() \
|
36 |
+
and issubclass(cls, BaseModel):
|
37 |
+
model = cls
|
38 |
+
|
39 |
+
if model is None:
|
40 |
+
print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
|
41 |
+
exit(0)
|
42 |
+
|
43 |
+
return model
|
44 |
+
|
45 |
+
|
46 |
+
def get_option_setter(model_name):
|
47 |
+
"""Return the static method <modify_commandline_options> of the model class."""
|
48 |
+
model_class = find_model_using_name(model_name)
|
49 |
+
return model_class.modify_commandline_options
|
50 |
+
|
51 |
+
|
52 |
+
def create_model(opt):
|
53 |
+
"""Create a model given the option.
|
54 |
+
|
55 |
+
This function warps the class CustomDatasetDataLoader.
|
56 |
+
This is the main interface between this package and 'train.py'/'test.py'
|
57 |
+
|
58 |
+
Example:
|
59 |
+
>>> from models import create_model
|
60 |
+
>>> model = create_model(opt)
|
61 |
+
"""
|
62 |
+
model = find_model_using_name(opt.model)
|
63 |
+
instance = model(opt)
|
64 |
+
print("model [%s] was created" % type(instance).__name__)
|
65 |
+
return instance
|
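A minimal sketch of the naming convention that `find_model_using_name`/`create_model` above rely on; `vatr_gan` is an illustrative name, not a model shipped in this package.

```python
# Minimal sketch: names resolved for a hypothetical '--model vatr_gan'.
model_name = 'vatr_gan'
module_name = "models." + model_name + "_model"        # -> "models.vatr_gan_model"
target_class = model_name.replace('_', '') + 'model'   # matched case-insensitively, e.g. VatrGanModel
print(module_name, target_class)
```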
models/blocks.py
ADDED
@@ -0,0 +1,190 @@
import torch
import torch.nn.functional as F
from torch import nn


class ResBlocks(nn.Module):
    def __init__(self, num_blocks, dim, norm, activation, pad_type):
        super(ResBlocks, self).__init__()
        self.model = []
        for i in range(num_blocks):
            self.model += [ResBlock(dim,
                                    norm=norm,
                                    activation=activation,
                                    pad_type=pad_type)]
        self.model = nn.Sequential(*self.model)

    def forward(self, x):
        return self.model(x)


class ResBlock(nn.Module):
    def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
        super(ResBlock, self).__init__()
        model = []
        model += [Conv2dBlock(dim, dim, 3, 1, 1,
                              norm=norm,
                              activation=activation,
                              pad_type=pad_type)]
        model += [Conv2dBlock(dim, dim, 3, 1, 1,
                              norm=norm,
                              activation='none',
                              pad_type=pad_type)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        residual = x
        out = self.model(x)
        out += residual
        return out


class ActFirstResBlock(nn.Module):
    def __init__(self, fin, fout, fhid=None,
                 activation='lrelu', norm='none'):
        super().__init__()
        self.learned_shortcut = (fin != fout)
        self.fin = fin
        self.fout = fout
        self.fhid = min(fin, fout) if fhid is None else fhid
        self.conv_0 = Conv2dBlock(self.fin, self.fhid, 3, 1,
                                  padding=1, pad_type='reflect', norm=norm,
                                  activation=activation, activation_first=True)
        self.conv_1 = Conv2dBlock(self.fhid, self.fout, 3, 1,
                                  padding=1, pad_type='reflect', norm=norm,
                                  activation=activation, activation_first=True)
        if self.learned_shortcut:
            self.conv_s = Conv2dBlock(self.fin, self.fout, 1, 1,
                                      activation='none', use_bias=False)

    def forward(self, x):
        x_s = self.conv_s(x) if self.learned_shortcut else x
        dx = self.conv_0(x)
        dx = self.conv_1(dx)
        out = x_s + dx
        return out


class LinearBlock(nn.Module):
    def __init__(self, in_dim, out_dim, norm='none', activation='relu'):
        super(LinearBlock, self).__init__()
        use_bias = True
        self.fc = nn.Linear(in_dim, out_dim, bias=use_bias)

        # initialize normalization
        norm_dim = out_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm1d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm1d(norm_dim)
        elif norm == 'none':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=False)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=False)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

    def forward(self, x):
        out = self.fc(x)
        if self.norm:
            out = self.norm(out)
        if self.activation:
            out = self.activation(out)
        return out


class Conv2dBlock(nn.Module):
    def __init__(self, in_dim, out_dim, ks, st, padding=0,
                 norm='none', activation='relu', pad_type='zero',
                 use_bias=True, activation_first=False):
        super(Conv2dBlock, self).__init__()
        self.use_bias = use_bias
        self.activation_first = activation_first
        # initialize padding
        if pad_type == 'reflect':
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == 'replicate':
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == 'zero':
            self.pad = nn.ZeroPad2d(padding)
        else:
            assert 0, "Unsupported padding type: {}".format(pad_type)

        # initialize normalization
        norm_dim = out_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm2d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm2d(norm_dim)
        elif norm == 'adain':
            self.norm = AdaptiveInstanceNorm2d(norm_dim)
        elif norm == 'none':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=False)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=False)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

        self.conv = nn.Conv2d(in_dim, out_dim, ks, st, bias=self.use_bias)

    def forward(self, x):
        if self.activation_first:
            if self.activation:
                x = self.activation(x)
            x = self.conv(self.pad(x))
            if self.norm:
                x = self.norm(x)
        else:
            x = self.conv(self.pad(x))
            if self.norm:
                x = self.norm(x)
            if self.activation:
                x = self.activation(x)
        return x


class AdaptiveInstanceNorm2d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(AdaptiveInstanceNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.weight = None
        self.bias = None
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        assert self.weight is not None and \
            self.bias is not None, "Please assign AdaIN weight first"
        b, c = x.size(0), x.size(1)
        running_mean = self.running_mean.repeat(b)
        running_var = self.running_var.repeat(b)
        x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
        out = F.batch_norm(
            x_reshaped, running_mean, running_var, self.weight, self.bias,
            True, self.momentum, self.eps)
        return out.view(b, c, *x.size()[2:])

    def __repr__(self):
        return self.__class__.__name__ + '(' + str(self.num_features) + ')'
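Conv2dBlock bundles padding, optional normalization ('bn', 'in', 'adain' or 'none') and activation around a single convolution, while AdaptiveInstanceNorm2d carries no affine parameters of its own: its `weight` and `bias` must be assigned externally (per batch, length B·C) before the first forward pass, otherwise the assert fires. A small usage sketch with made-up shapes:

```python
import torch
from models.blocks import Conv2dBlock, ResBlocks, AdaptiveInstanceNorm2d

x = torch.randn(2, 64, 32, 128)  # (B, C, H, W), made-up sizes

block = Conv2dBlock(64, 64, 3, 1, padding=1,
                    norm='adain', activation='relu', pad_type='reflect')

# AdaIN statistics come from outside (e.g. a style code); they must be set
# before forward() and have length B * C for the flattened batch-norm call.
adain = next(m for m in block.modules() if isinstance(m, AdaptiveInstanceNorm2d))
adain.weight = torch.ones(2 * 64)
adain.bias = torch.zeros(2 * 64)

y = block(x)                                   # same spatial size (reflect pad 1, 3x3 conv)
z = ResBlocks(2, 64, norm='in', activation='relu', pad_type='zero')(y)
```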
models/config.py
ADDED
@@ -0,0 +1,6 @@
tn_hidden_dim = 512
tn_dropout = 0.1
tn_nheads = 8
tn_dim_feedforward = 512
tn_enc_layers = 3
tn_dec_layers = 3
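These six values are the transformer hyper-parameters consumed by the generator in models/model.py (hidden size 512, 8 heads, feed-forward size 512, dropout 0.1, 3 encoder and 3 decoder layers). As a rough illustration only, an equivalent torch.nn stack built from this configuration would look as follows; the repository itself builds its encoder and decoder from models/transformer.py, not from torch.nn:

```python
import torch.nn as nn
import models.config as config

# Illustration only: an encoder/decoder stack with the hyper-parameters above.
enc_layer = nn.TransformerEncoderLayer(
    d_model=config.tn_hidden_dim, nhead=config.tn_nheads,
    dim_feedforward=config.tn_dim_feedforward, dropout=config.tn_dropout)
encoder = nn.TransformerEncoder(enc_layer, num_layers=config.tn_enc_layers)

dec_layer = nn.TransformerDecoderLayer(
    d_model=config.tn_hidden_dim, nhead=config.tn_nheads,
    dim_feedforward=config.tn_dim_feedforward, dropout=config.tn_dropout)
decoder = nn.TransformerDecoder(dec_layer, num_layers=config.tn_dec_layers)
```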
models/inception.py
ADDED
@@ -0,0 +1,311 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torchvision import models
|
5 |
+
|
6 |
+
try:
|
7 |
+
from torchvision.models.utils import load_state_dict_from_url
|
8 |
+
except ImportError:
|
9 |
+
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
10 |
+
|
11 |
+
# Inception weights ported to Pytorch from
|
12 |
+
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
|
13 |
+
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
|
14 |
+
|
15 |
+
|
16 |
+
class InceptionV3(nn.Module):
|
17 |
+
"""Pretrained InceptionV3 network returning feature maps"""
|
18 |
+
|
19 |
+
# Index of default block of inception to return,
|
20 |
+
# corresponds to output of final average pooling
|
21 |
+
DEFAULT_BLOCK_INDEX = 3
|
22 |
+
|
23 |
+
# Maps feature dimensionality to their output blocks indices
|
24 |
+
BLOCK_INDEX_BY_DIM = {
|
25 |
+
64: 0, # First max pooling features
|
26 |
+
192: 1,  # Second max pooling features
|
27 |
+
768: 2, # Pre-aux classifier features
|
28 |
+
2048: 3 # Final average pooling features
|
29 |
+
}
|
30 |
+
|
31 |
+
def __init__(self,
|
32 |
+
output_blocks=[DEFAULT_BLOCK_INDEX],
|
33 |
+
resize_input=True,
|
34 |
+
normalize_input=True,
|
35 |
+
requires_grad=False,
|
36 |
+
use_fid_inception=True):
|
37 |
+
"""Build pretrained InceptionV3
|
38 |
+
|
39 |
+
Parameters
|
40 |
+
----------
|
41 |
+
output_blocks : list of int
|
42 |
+
Indices of blocks to return features of. Possible values are:
|
43 |
+
- 0: corresponds to output of first max pooling
|
44 |
+
- 1: corresponds to output of second max pooling
|
45 |
+
- 2: corresponds to output which is fed to aux classifier
|
46 |
+
- 3: corresponds to output of final average pooling
|
47 |
+
resize_input : bool
|
48 |
+
If true, bilinearly resizes input to width and height 299 before
|
49 |
+
feeding input to model. As the network without fully connected
|
50 |
+
layers is fully convolutional, it should be able to handle inputs
|
51 |
+
of arbitrary size, so resizing might not be strictly needed
|
52 |
+
normalize_input : bool
|
53 |
+
If true, scales the input from range (0, 1) to the range the
|
54 |
+
pretrained Inception network expects, namely (-1, 1)
|
55 |
+
requires_grad : bool
|
56 |
+
If true, parameters of the model require gradients. Possibly useful
|
57 |
+
for finetuning the network
|
58 |
+
use_fid_inception : bool
|
59 |
+
If true, uses the pretrained Inception model used in Tensorflow's
|
60 |
+
FID implementation. If false, uses the pretrained Inception model
|
61 |
+
available in torchvision. The FID Inception model has different
|
62 |
+
weights and a slightly different structure from torchvision's
|
63 |
+
Inception model. If you want to compute FID scores, you are
|
64 |
+
strongly advised to set this parameter to true to get comparable
|
65 |
+
results.
|
66 |
+
"""
|
67 |
+
super(InceptionV3, self).__init__()
|
68 |
+
|
69 |
+
self.resize_input = resize_input
|
70 |
+
self.normalize_input = normalize_input
|
71 |
+
self.output_blocks = sorted(output_blocks)
|
72 |
+
self.last_needed_block = max(output_blocks)
|
73 |
+
|
74 |
+
assert self.last_needed_block <= 3, \
|
75 |
+
'Last possible output block index is 3'
|
76 |
+
|
77 |
+
self.blocks = nn.ModuleList()
|
78 |
+
|
79 |
+
if use_fid_inception:
|
80 |
+
inception = fid_inception_v3()
|
81 |
+
else:
|
82 |
+
inception = models.inception_v3(pretrained=True)
|
83 |
+
|
84 |
+
# Block 0: input to maxpool1
|
85 |
+
block0 = [
|
86 |
+
inception.Conv2d_1a_3x3,
|
87 |
+
inception.Conv2d_2a_3x3,
|
88 |
+
inception.Conv2d_2b_3x3,
|
89 |
+
nn.MaxPool2d(kernel_size=3, stride=2)
|
90 |
+
]
|
91 |
+
self.blocks.append(nn.Sequential(*block0))
|
92 |
+
|
93 |
+
# Block 1: maxpool1 to maxpool2
|
94 |
+
if self.last_needed_block >= 1:
|
95 |
+
block1 = [
|
96 |
+
inception.Conv2d_3b_1x1,
|
97 |
+
inception.Conv2d_4a_3x3,
|
98 |
+
nn.MaxPool2d(kernel_size=3, stride=2)
|
99 |
+
]
|
100 |
+
self.blocks.append(nn.Sequential(*block1))
|
101 |
+
|
102 |
+
# Block 2: maxpool2 to aux classifier
|
103 |
+
if self.last_needed_block >= 2:
|
104 |
+
block2 = [
|
105 |
+
inception.Mixed_5b,
|
106 |
+
inception.Mixed_5c,
|
107 |
+
inception.Mixed_5d,
|
108 |
+
inception.Mixed_6a,
|
109 |
+
inception.Mixed_6b,
|
110 |
+
inception.Mixed_6c,
|
111 |
+
inception.Mixed_6d,
|
112 |
+
inception.Mixed_6e,
|
113 |
+
]
|
114 |
+
self.blocks.append(nn.Sequential(*block2))
|
115 |
+
|
116 |
+
# Block 3: aux classifier to final avgpool
|
117 |
+
if self.last_needed_block >= 3:
|
118 |
+
block3 = [
|
119 |
+
inception.Mixed_7a,
|
120 |
+
inception.Mixed_7b,
|
121 |
+
inception.Mixed_7c,
|
122 |
+
nn.AdaptiveAvgPool2d(output_size=(1, 1))
|
123 |
+
]
|
124 |
+
self.blocks.append(nn.Sequential(*block3))
|
125 |
+
|
126 |
+
for param in self.parameters():
|
127 |
+
param.requires_grad = requires_grad
|
128 |
+
|
129 |
+
def forward(self, inp):
|
130 |
+
"""Get Inception feature maps
|
131 |
+
|
132 |
+
Parameters
|
133 |
+
----------
|
134 |
+
inp : torch.autograd.Variable
|
135 |
+
Input tensor of shape Bx3xHxW. Values are expected to be in
|
136 |
+
range (0, 1)
|
137 |
+
|
138 |
+
Returns
|
139 |
+
-------
|
140 |
+
List of torch.autograd.Variable, corresponding to the selected output
|
141 |
+
block, sorted ascending by index
|
142 |
+
"""
|
143 |
+
outp = []
|
144 |
+
x = inp
|
145 |
+
|
146 |
+
if self.resize_input:
|
147 |
+
x = F.interpolate(x,
|
148 |
+
size=(299, 299),
|
149 |
+
mode='bilinear',
|
150 |
+
align_corners=False)
|
151 |
+
|
152 |
+
if self.normalize_input:
|
153 |
+
x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
|
154 |
+
|
155 |
+
for idx, block in enumerate(self.blocks):
|
156 |
+
x = block(x)
|
157 |
+
if idx in self.output_blocks:
|
158 |
+
outp.append(x)
|
159 |
+
|
160 |
+
if idx == self.last_needed_block:
|
161 |
+
break
|
162 |
+
|
163 |
+
return outp
|
164 |
+
|
165 |
+
|
166 |
+
def fid_inception_v3():
|
167 |
+
"""Build pretrained Inception model for FID computation
|
168 |
+
|
169 |
+
The Inception model for FID computation uses a different set of weights
|
170 |
+
and has a slightly different structure than torchvision's Inception.
|
171 |
+
|
172 |
+
This method first constructs torchvision's Inception and then patches the
|
173 |
+
necessary parts that are different in the FID Inception model.
|
174 |
+
"""
|
175 |
+
inception = models.inception_v3(num_classes=1008,
|
176 |
+
aux_logits=False,
|
177 |
+
weights=None,
|
178 |
+
init_weights=False)
|
179 |
+
inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
|
180 |
+
inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
|
181 |
+
inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
|
182 |
+
inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
|
183 |
+
inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
|
184 |
+
inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
|
185 |
+
inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
|
186 |
+
inception.Mixed_7b = FIDInceptionE_1(1280)
|
187 |
+
inception.Mixed_7c = FIDInceptionE_2(2048)
|
188 |
+
|
189 |
+
state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
|
190 |
+
inception.load_state_dict(state_dict)
|
191 |
+
return inception
|
192 |
+
|
193 |
+
|
194 |
+
class FIDInceptionA(models.inception.InceptionA):
|
195 |
+
"""InceptionA block patched for FID computation"""
|
196 |
+
def __init__(self, in_channels, pool_features):
|
197 |
+
super(FIDInceptionA, self).__init__(in_channels, pool_features)
|
198 |
+
|
199 |
+
def forward(self, x):
|
200 |
+
branch1x1 = self.branch1x1(x)
|
201 |
+
|
202 |
+
branch5x5 = self.branch5x5_1(x)
|
203 |
+
branch5x5 = self.branch5x5_2(branch5x5)
|
204 |
+
|
205 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
206 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
207 |
+
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
|
208 |
+
|
209 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
210 |
+
# its average calculation
|
211 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
212 |
+
count_include_pad=False)
|
213 |
+
branch_pool = self.branch_pool(branch_pool)
|
214 |
+
|
215 |
+
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
|
216 |
+
return torch.cat(outputs, 1)
|
217 |
+
|
218 |
+
|
219 |
+
class FIDInceptionC(models.inception.InceptionC):
|
220 |
+
"""InceptionC block patched for FID computation"""
|
221 |
+
def __init__(self, in_channels, channels_7x7):
|
222 |
+
super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
|
223 |
+
|
224 |
+
def forward(self, x):
|
225 |
+
branch1x1 = self.branch1x1(x)
|
226 |
+
|
227 |
+
branch7x7 = self.branch7x7_1(x)
|
228 |
+
branch7x7 = self.branch7x7_2(branch7x7)
|
229 |
+
branch7x7 = self.branch7x7_3(branch7x7)
|
230 |
+
|
231 |
+
branch7x7dbl = self.branch7x7dbl_1(x)
|
232 |
+
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
|
233 |
+
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
|
234 |
+
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
|
235 |
+
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
|
236 |
+
|
237 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
238 |
+
# its average calculation
|
239 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
240 |
+
count_include_pad=False)
|
241 |
+
branch_pool = self.branch_pool(branch_pool)
|
242 |
+
|
243 |
+
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
|
244 |
+
return torch.cat(outputs, 1)
|
245 |
+
|
246 |
+
|
247 |
+
class FIDInceptionE_1(models.inception.InceptionE):
|
248 |
+
"""First InceptionE block patched for FID computation"""
|
249 |
+
def __init__(self, in_channels):
|
250 |
+
super(FIDInceptionE_1, self).__init__(in_channels)
|
251 |
+
|
252 |
+
def forward(self, x):
|
253 |
+
branch1x1 = self.branch1x1(x)
|
254 |
+
|
255 |
+
branch3x3 = self.branch3x3_1(x)
|
256 |
+
branch3x3 = [
|
257 |
+
self.branch3x3_2a(branch3x3),
|
258 |
+
self.branch3x3_2b(branch3x3),
|
259 |
+
]
|
260 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
261 |
+
|
262 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
263 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
264 |
+
branch3x3dbl = [
|
265 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
266 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
267 |
+
]
|
268 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
269 |
+
|
270 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
271 |
+
# its average calculation
|
272 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
273 |
+
count_include_pad=False)
|
274 |
+
branch_pool = self.branch_pool(branch_pool)
|
275 |
+
|
276 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
277 |
+
return torch.cat(outputs, 1)
|
278 |
+
|
279 |
+
|
280 |
+
class FIDInceptionE_2(models.inception.InceptionE):
|
281 |
+
"""Second InceptionE block patched for FID computation"""
|
282 |
+
def __init__(self, in_channels):
|
283 |
+
super(FIDInceptionE_2, self).__init__(in_channels)
|
284 |
+
|
285 |
+
def forward(self, x):
|
286 |
+
branch1x1 = self.branch1x1(x)
|
287 |
+
|
288 |
+
branch3x3 = self.branch3x3_1(x)
|
289 |
+
branch3x3 = [
|
290 |
+
self.branch3x3_2a(branch3x3),
|
291 |
+
self.branch3x3_2b(branch3x3),
|
292 |
+
]
|
293 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
294 |
+
|
295 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
296 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
297 |
+
branch3x3dbl = [
|
298 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
299 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
300 |
+
]
|
301 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
302 |
+
|
303 |
+
# Patch: The FID Inception model uses max pooling instead of average
|
304 |
+
# pooling. This is likely an error in this specific Inception
|
305 |
+
# implementation, as other Inception models use average pooling here
|
306 |
+
# (which matches the description in the paper).
|
307 |
+
branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
|
308 |
+
branch_pool = self.branch_pool(branch_pool)
|
309 |
+
|
310 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
311 |
+
return torch.cat(outputs, 1)
|
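InceptionV3 above is the standard pytorch-fid feature extractor: it returns intermediate feature maps rather than logits, and with the default block index 3 it yields the 2048-dimensional pooled activations on which FID statistics are computed. A short sketch (the batch below is random data; real inputs must be 3-channel images scaled to [0, 1]):

```python
import torch
from models.inception import InceptionV3

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]   # final average-pooling features
net = InceptionV3([block_idx]).eval()              # downloads the FID weights on first use

imgs = torch.rand(4, 3, 64, 256)                   # assumed batch, values in [0, 1]
with torch.no_grad():
    feats = net(imgs)[0]                           # (4, 2048, 1, 1)
feats = feats.squeeze(-1).squeeze(-1)              # (4, 2048) activations for FID statistics
```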
models/model.py
ADDED
@@ -0,0 +1,894 @@
1 |
+
import torch.utils.data
|
2 |
+
from torch.nn import CTCLoss
|
3 |
+
from torch.nn.utils import clip_grad_norm_
|
4 |
+
import sys
|
5 |
+
import torchvision.models as models
|
6 |
+
|
7 |
+
from models.inception import InceptionV3
|
8 |
+
from models.transformer import *
|
9 |
+
from util.augmentations import OCRAugment
|
10 |
+
from util.misc import SmoothedValue
|
11 |
+
from util.text import get_generator, AugmentedGenerator
|
12 |
+
from .BigGAN_networks import *
|
13 |
+
from .OCR_network import *
|
14 |
+
from models.blocks import Conv2dBlock, ResBlocks
|
15 |
+
from util.util import loss_hinge_dis, loss_hinge_gen, make_one_hot
|
16 |
+
|
17 |
+
import models.config as config
|
18 |
+
from .positional_encodings import PositionalEncoding1D
|
19 |
+
from models.unifont_module import UnifontModule
|
20 |
+
from PIL import Image
|
21 |
+
|
22 |
+
|
23 |
+
def get_rgb(x):
|
24 |
+
R = 255 - int(int(x > 0.5) * 255 * (x - 0.5) / 0.5)
|
25 |
+
G = 0
|
26 |
+
B = 255 + int(int(x < 0.5) * 255 * (x - 0.5) / 0.5)
|
27 |
+
return R, G, B
|
28 |
+
|
29 |
+
|
30 |
+
def get_page_from_words(word_lists, MAX_IMG_WIDTH=800):
|
31 |
+
line_all = []
|
32 |
+
line_t = []
|
33 |
+
|
34 |
+
width_t = 0
|
35 |
+
|
36 |
+
for i in word_lists:
|
37 |
+
|
38 |
+
width_t = width_t + i.shape[1] + 16
|
39 |
+
|
40 |
+
if width_t > MAX_IMG_WIDTH:
|
41 |
+
line_all.append(np.concatenate(line_t, 1))
|
42 |
+
|
43 |
+
line_t = []
|
44 |
+
|
45 |
+
width_t = i.shape[1] + 16
|
46 |
+
|
47 |
+
line_t.append(i)
|
48 |
+
line_t.append(np.ones((i.shape[0], 16)))
|
49 |
+
|
50 |
+
if len(line_all) == 0:
|
51 |
+
line_all.append(np.concatenate(line_t, 1))
|
52 |
+
|
53 |
+
max_lin_widths = MAX_IMG_WIDTH # max([i.shape[1] for i in line_all])
|
54 |
+
gap_h = np.ones([16, max_lin_widths])
|
55 |
+
|
56 |
+
page_ = []
|
57 |
+
|
58 |
+
for l in line_all:
|
59 |
+
pad_ = np.ones([l.shape[0], max_lin_widths - l.shape[1]])
|
60 |
+
|
61 |
+
page_.append(np.concatenate([l, pad_], 1))
|
62 |
+
page_.append(gap_h)
|
63 |
+
|
64 |
+
page = np.concatenate(page_, 0)
|
65 |
+
|
66 |
+
return page * 255
|
67 |
+
|
68 |
+
|
69 |
+
class FCNDecoder(nn.Module):
|
70 |
+
def __init__(self, ups=3, n_res=2, dim=512, out_dim=1, res_norm='adain', activ='relu', pad_type='reflect'):
|
71 |
+
super(FCNDecoder, self).__init__()
|
72 |
+
|
73 |
+
self.model = []
|
74 |
+
self.model += [ResBlocks(n_res, dim, res_norm,
|
75 |
+
activ, pad_type=pad_type)]
|
76 |
+
for i in range(ups):
|
77 |
+
self.model += [nn.Upsample(scale_factor=2),
|
78 |
+
Conv2dBlock(dim, dim // 2, 5, 1, 2,
|
79 |
+
norm='in',
|
80 |
+
activation=activ,
|
81 |
+
pad_type=pad_type)]
|
82 |
+
dim //= 2
|
83 |
+
self.model += [Conv2dBlock(dim, out_dim, 7, 1, 3,
|
84 |
+
norm='none',
|
85 |
+
activation='tanh',
|
86 |
+
pad_type=pad_type)]
|
87 |
+
self.model = nn.Sequential(*self.model)
|
88 |
+
|
89 |
+
def forward(self, x):
|
90 |
+
y = self.model(x)
|
91 |
+
|
92 |
+
return y
|
93 |
+
|
94 |
+
|
95 |
+
class Generator(nn.Module):
|
96 |
+
|
97 |
+
def __init__(self, args):
|
98 |
+
super(Generator, self).__init__()
|
99 |
+
self.args = args
|
100 |
+
INP_CHANNEL = 1
|
101 |
+
|
102 |
+
encoder_layer = TransformerEncoderLayer(config.tn_hidden_dim, config.tn_nheads,
|
103 |
+
config.tn_dim_feedforward,
|
104 |
+
config.tn_dropout, "relu", True)
|
105 |
+
encoder_norm = nn.LayerNorm(config.tn_hidden_dim) if True else None
|
106 |
+
self.encoder = TransformerEncoder(encoder_layer, config.tn_enc_layers, encoder_norm)
|
107 |
+
|
108 |
+
decoder_layer = TransformerDecoderLayer(config.tn_hidden_dim, config.tn_nheads,
|
109 |
+
config.tn_dim_feedforward,
|
110 |
+
config.tn_dropout, "relu", True)
|
111 |
+
decoder_norm = nn.LayerNorm(config.tn_hidden_dim)
|
112 |
+
self.decoder = TransformerDecoder(decoder_layer, config.tn_dec_layers, decoder_norm,
|
113 |
+
return_intermediate=True)
|
114 |
+
|
115 |
+
self.Feat_Encoder = models.resnet18(weights='ResNet18_Weights.DEFAULT')
|
116 |
+
self.Feat_Encoder.conv1 = nn.Conv2d(INP_CHANNEL, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
117 |
+
self.Feat_Encoder.fc = nn.Identity()
|
118 |
+
self.Feat_Encoder.avgpool = nn.Identity()
|
119 |
+
|
120 |
+
# self.query_embed = nn.Embedding(self.args.vocab_size, self.args.tn_hidden_dim)
|
121 |
+
self.query_embed = UnifontModule(
|
122 |
+
config.tn_dim_feedforward,
|
123 |
+
self.args.alphabet + self.args.special_alphabet,
|
124 |
+
input_type=self.args.query_input,
|
125 |
+
device=self.args.device
|
126 |
+
)
|
127 |
+
|
128 |
+
self.pos_encoder = PositionalEncoding1D(config.tn_hidden_dim)
|
129 |
+
|
130 |
+
self.linear_q = nn.Linear(config.tn_dim_feedforward, config.tn_dim_feedforward * 8)
|
131 |
+
|
132 |
+
self.DEC = FCNDecoder(res_norm='in', dim=config.tn_hidden_dim)
|
133 |
+
|
134 |
+
self.noise = torch.distributions.Normal(loc=torch.tensor([0.]), scale=torch.tensor([1.0]))
|
135 |
+
|
136 |
+
def evaluate(self, style_images, queries):
|
137 |
+
style = self.compute_style(style_images)
|
138 |
+
|
139 |
+
results = []
|
140 |
+
|
141 |
+
for i in range(queries.shape[1]):
|
142 |
+
query = queries[:, i, :]
|
143 |
+
h = self.generate(style, query)
|
144 |
+
|
145 |
+
results.append(h.detach())
|
146 |
+
|
147 |
+
return results
|
148 |
+
|
149 |
+
def compute_style(self, style_images):
|
150 |
+
B, N, R, C = style_images.shape
|
151 |
+
FEAT_ST = self.Feat_Encoder(style_images.view(B * N, 1, R, C))
|
152 |
+
FEAT_ST = FEAT_ST.view(B, 512, 1, -1)
|
153 |
+
FEAT_ST_ENC = FEAT_ST.flatten(2).permute(2, 0, 1)
|
154 |
+
memory = self.encoder(FEAT_ST_ENC)
|
155 |
+
return memory
|
156 |
+
|
157 |
+
def generate(self, style_vector, query):
|
158 |
+
query_embed = self.query_embed(query).permute(1, 0, 2)
|
159 |
+
|
160 |
+
tgt = torch.zeros_like(query_embed)
|
161 |
+
hs = self.decoder(tgt, style_vector, query_pos=query_embed)
|
162 |
+
|
163 |
+
h = hs.transpose(1, 2)[-1]
|
164 |
+
|
165 |
+
if self.args.add_noise:
|
166 |
+
h = h + self.noise.sample(h.size()).squeeze(-1).to(self.args.device)
|
167 |
+
|
168 |
+
h = self.linear_q(h)
|
169 |
+
h = h.contiguous()
|
170 |
+
|
171 |
+
h = h.view(h.size(0), h.shape[1] * 2, 4, -1)
|
172 |
+
h = h.permute(0, 3, 2, 1)
|
173 |
+
|
174 |
+
h = self.DEC(h)
|
175 |
+
|
176 |
+
return h
|
177 |
+
|
178 |
+
def forward(self, style_images, query):
|
179 |
+
enc_attn_weights, dec_attn_weights = [], []
|
180 |
+
|
181 |
+
self.hooks = [
|
182 |
+
|
183 |
+
self.encoder.layers[-1].self_attn.register_forward_hook(
|
184 |
+
lambda self, input, output: enc_attn_weights.append(output[1])
|
185 |
+
),
|
186 |
+
self.decoder.layers[-1].multihead_attn.register_forward_hook(
|
187 |
+
lambda self, input, output: dec_attn_weights.append(output[1])
|
188 |
+
),
|
189 |
+
]
|
190 |
+
|
191 |
+
style = self.compute_style(style_images)
|
192 |
+
|
193 |
+
h = self.generate(style, query)
|
194 |
+
|
195 |
+
self.dec_attn_weights = dec_attn_weights[-1].detach()
|
196 |
+
self.enc_attn_weights = enc_attn_weights[-1].detach()
|
197 |
+
|
198 |
+
for hook in self.hooks:
|
199 |
+
hook.remove()
|
200 |
+
|
201 |
+
return h, style
|
202 |
+
|
203 |
+
|
204 |
+
class VATr(nn.Module):
|
205 |
+
|
206 |
+
def __init__(self, args):
|
207 |
+
super(VATr, self).__init__()
|
208 |
+
self.args = args
|
209 |
+
self.args.vocab_size = len(args.alphabet)
|
210 |
+
|
211 |
+
self.epsilon = 1e-7
|
212 |
+
self.netG = Generator(self.args).to(self.args.device)
|
213 |
+
self.netD = Discriminator(
|
214 |
+
resolution=self.args.resolution, crop_size=args.d_crop_size,
|
215 |
+
).to(self.args.device)
|
216 |
+
|
217 |
+
self.netW = WDiscriminator(resolution=self.args.resolution, n_classes=self.args.vocab_size, output_dim=self.args.num_writers)
|
218 |
+
self.netW = self.netW.to(self.args.device)
|
219 |
+
self.netconverter = strLabelConverter(self.args.alphabet + self.args.special_alphabet)
|
220 |
+
|
221 |
+
self.netOCR = CRNN(self.args).to(self.args.device)
|
222 |
+
|
223 |
+
self.ocr_augmenter = OCRAugment(prob=0.5, no=3)
|
224 |
+
self.OCR_criterion = CTCLoss(zero_infinity=True, reduction='none')
|
225 |
+
|
226 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
|
227 |
+
self.inception = InceptionV3([block_idx]).to(self.args.device)
|
228 |
+
|
229 |
+
self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
|
230 |
+
lr=self.args.g_lr, betas=(0.0, 0.999), weight_decay=0, eps=1e-8)
|
231 |
+
|
232 |
+
self.optimizer_OCR = torch.optim.Adam(self.netOCR.parameters(),
|
233 |
+
lr=self.args.ocr_lr, betas=(0.0, 0.999), weight_decay=0, eps=1e-8)
|
234 |
+
|
235 |
+
self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
|
236 |
+
lr=self.args.d_lr, betas=(0.0, 0.999), weight_decay=0, eps=1e-8)
|
237 |
+
|
238 |
+
self.optimizer_wl = torch.optim.Adam(self.netW.parameters(),
|
239 |
+
lr=self.args.w_lr, betas=(0.0, 0.999), weight_decay=0, eps=1e-8)
|
240 |
+
|
241 |
+
self.optimizers = [self.optimizer_G, self.optimizer_OCR, self.optimizer_D, self.optimizer_wl]
|
242 |
+
|
243 |
+
self.optimizer_G.zero_grad()
|
244 |
+
self.optimizer_OCR.zero_grad()
|
245 |
+
self.optimizer_D.zero_grad()
|
246 |
+
self.optimizer_wl.zero_grad()
|
247 |
+
|
248 |
+
self.loss_G = 0
|
249 |
+
self.loss_D = 0
|
250 |
+
self.loss_Dfake = 0
|
251 |
+
self.loss_Dreal = 0
|
252 |
+
self.loss_OCR_fake = 0
|
253 |
+
self.loss_OCR_real = 0
|
254 |
+
self.loss_w_fake = 0
|
255 |
+
self.loss_w_real = 0
|
256 |
+
self.Lcycle = 0
|
257 |
+
self.d_acc = SmoothedValue()
|
258 |
+
|
259 |
+
self.word_generator = get_generator(args)
|
260 |
+
|
261 |
+
self.epoch = 0
|
262 |
+
|
263 |
+
with open('mytext.txt', 'r', encoding='utf-8') as f:
|
264 |
+
self.text = f.read()
|
265 |
+
self.text = self.text.replace('\n', ' ')
|
266 |
+
self.text = self.text.replace('\n', ' ')
|
267 |
+
self.text = ''.join(c for c in self.text if c in (self.args.alphabet + self.args.special_alphabet)) # just to avoid problems with the font dataset
|
268 |
+
self.text = [word.encode() for word in self.text.split()] # [:args.num_examples]
|
269 |
+
|
270 |
+
self.eval_text_encode, self.eval_len_text, self.eval_encode_pos = self.netconverter.encode(self.text)
|
271 |
+
self.eval_text_encode = self.eval_text_encode.to(self.args.device).repeat(self.args.batch_size, 1, 1)
|
272 |
+
|
273 |
+
self.rv_sample_size = 64 * 4
|
274 |
+
self.last_fakes = []
|
275 |
+
|
276 |
+
def update_last_fakes(self, fakes):
|
277 |
+
for fake in fakes:
|
278 |
+
self.last_fakes.append(fake)
|
279 |
+
self.last_fakes = self.last_fakes[-self.rv_sample_size:]
|
280 |
+
|
281 |
+
def update_acc(self, pred_real, pred_fake):
|
282 |
+
correct = (pred_real >= 0.5).float().sum() + (pred_fake < 0.5).float().sum()
|
283 |
+
self.d_acc.update(correct / (len(pred_real) + len(pred_fake)))
|
284 |
+
|
285 |
+
def set_text_aug_strength(self, strength):
|
286 |
+
if not isinstance(self.word_generator, AugmentedGenerator):
|
287 |
+
print("WARNING: Text generator is not augmented, strength cannot be set")
|
288 |
+
else:
|
289 |
+
self.word_generator.set_strength(strength)
|
290 |
+
|
291 |
+
def get_text_aug_strength(self):
|
292 |
+
if isinstance(self.word_generator, AugmentedGenerator):
|
293 |
+
return self.word_generator.strength
|
294 |
+
else:
|
295 |
+
return 0.0
|
296 |
+
|
297 |
+
def update_parameters(self, epoch: int):
|
298 |
+
self.epoch = epoch
|
299 |
+
self.netD.update_parameters(epoch)
|
300 |
+
self.netW.update_parameters(epoch)
|
301 |
+
|
302 |
+
def get_text_sample(self, size: int) -> list:
|
303 |
+
return [self.word_generator.generate() for _ in range(size)]
|
304 |
+
|
305 |
+
def _generate_fakes(self, ST, eval_text_encode=None, eval_len_text=None):
|
306 |
+
if eval_text_encode == None:
|
307 |
+
eval_text_encode = self.eval_text_encode
|
308 |
+
if eval_len_text == None:
|
309 |
+
eval_len_text = self.eval_len_text
|
310 |
+
|
311 |
+
self.fakes = self.netG.evaluate(ST, eval_text_encode)
|
312 |
+
|
313 |
+
np_fakes = []
|
314 |
+
for batch_idx in range(self.fakes[0].shape[0]):
|
315 |
+
for idx, fake in enumerate(self.fakes):
|
316 |
+
fake = fake[batch_idx, 0, :, :eval_len_text[idx] * self.args.resolution]
|
317 |
+
fake = (fake + 1) / 2
|
318 |
+
np_fakes.append(fake.cpu().numpy())
|
319 |
+
return np_fakes
|
320 |
+
|
321 |
+
def _generate_page(self, ST, SLEN, eval_text_encode=None, eval_len_text=None, eval_encode_pos=None, lwidth=260, rwidth=980):
|
322 |
+
# ST -> Style?
|
323 |
+
|
324 |
+
if eval_text_encode == None:
|
325 |
+
eval_text_encode = self.eval_text_encode
|
326 |
+
if eval_len_text == None:
|
327 |
+
eval_len_text = self.eval_len_text
|
328 |
+
if eval_encode_pos is None:
|
329 |
+
eval_encode_pos = self.eval_encode_pos
|
330 |
+
|
331 |
+
text_encode, text_len, _ = self.netconverter.encode(self.args.special_alphabet)
|
332 |
+
symbols = self.netG.query_embed.symbols[text_encode].reshape(-1, 16, 16).cpu().numpy()
|
333 |
+
imgs = [Image.fromarray(s).resize((32, 32), resample=0) for s in symbols]
|
334 |
+
special_examples = 1 - np.concatenate([np.array(i) for i in imgs], axis=-1)
|
335 |
+
|
336 |
+
self.fakes = self.netG.evaluate(ST, eval_text_encode)
|
337 |
+
|
338 |
+
page1s = []
|
339 |
+
page2s = []
|
340 |
+
|
341 |
+
for batch_idx in range(ST.shape[0]):
|
342 |
+
|
343 |
+
word_t = []
|
344 |
+
word_l = []
|
345 |
+
|
346 |
+
gap = np.ones([self.args.img_height, 16])
|
347 |
+
|
348 |
+
line_wids = []
|
349 |
+
|
350 |
+
for idx, fake_ in enumerate(self.fakes):
|
351 |
+
|
352 |
+
word_t.append((fake_[batch_idx, 0, :, :eval_len_text[idx] * self.args.resolution].cpu().numpy() + 1) / 2)
|
353 |
+
|
354 |
+
word_t.append(gap)
|
355 |
+
|
356 |
+
if sum(t.shape[-1] for t in word_t) >= rwidth or idx == len(self.fakes) - 1 or (len(self.fakes) - len(self.args.special_alphabet) - 1) == idx:
|
357 |
+
line_ = np.concatenate(word_t, -1)
|
358 |
+
|
359 |
+
word_l.append(line_)
|
360 |
+
line_wids.append(line_.shape[1])
|
361 |
+
|
362 |
+
word_t = []
|
363 |
+
|
364 |
+
# add the examples from the UnifontModules
|
365 |
+
word_l.append(special_examples)
|
366 |
+
line_wids.append(special_examples.shape[1])
|
367 |
+
|
368 |
+
gap_h = np.ones([16, max(line_wids)])
|
369 |
+
|
370 |
+
page_ = []
|
371 |
+
|
372 |
+
for l in word_l:
|
373 |
+
pad_ = np.ones([self.args.img_height, max(line_wids) - l.shape[1]])
|
374 |
+
|
375 |
+
page_.append(np.concatenate([l, pad_], 1))
|
376 |
+
page_.append(gap_h)
|
377 |
+
|
378 |
+
page1 = np.concatenate(page_, 0)
|
379 |
+
|
380 |
+
word_t = []
|
381 |
+
word_l = []
|
382 |
+
|
383 |
+
|
384 |
+
line_wids = []
|
385 |
+
|
386 |
+
sdata_ = [i.unsqueeze(1) for i in torch.unbind(ST, 1)]
|
387 |
+
gap = np.ones([sdata_[0].shape[-2], 16])
|
388 |
+
|
389 |
+
for idx, st in enumerate((sdata_)):
|
390 |
+
|
391 |
+
word_t.append((st[batch_idx, 0, :, :int(SLEN.cpu().numpy()[batch_idx][idx])].cpu().numpy() + 1) / 2)
|
392 |
+
# word_t.append((st[batch_idx, 0, :, :].cpu().numpy() + 1) / 2)
|
393 |
+
|
394 |
+
word_t.append(gap)
|
395 |
+
|
396 |
+
if sum(t.shape[-1] for t in word_t) >= lwidth or idx == len(sdata_) - 1:
|
397 |
+
line_ = np.concatenate(word_t, -1)
|
398 |
+
|
399 |
+
word_l.append(line_)
|
400 |
+
line_wids.append(line_.shape[1])
|
401 |
+
|
402 |
+
word_t = []
|
403 |
+
|
404 |
+
gap_h = np.ones([16, max(line_wids)])
|
405 |
+
|
406 |
+
page_ = []
|
407 |
+
|
408 |
+
for l in word_l:
|
409 |
+
pad_ = np.ones([sdata_[0].shape[-2], max(line_wids) - l.shape[1]])
|
410 |
+
|
411 |
+
page_.append(np.concatenate([l, pad_], 1))
|
412 |
+
page_.append(gap_h)
|
413 |
+
|
414 |
+
page2 = np.concatenate(page_, 0)
|
415 |
+
|
416 |
+
merge_w_size = max(page1.shape[0], page2.shape[0])
|
417 |
+
|
418 |
+
if page1.shape[0] != merge_w_size:
|
419 |
+
page1 = np.concatenate([page1, np.ones([merge_w_size - page1.shape[0], page1.shape[1]])], 0)
|
420 |
+
|
421 |
+
if page2.shape[0] != merge_w_size:
|
422 |
+
page2 = np.concatenate([page2, np.ones([merge_w_size - page2.shape[0], page2.shape[1]])], 0)
|
423 |
+
|
424 |
+
page1s.append(page1)
|
425 |
+
page2s.append(page2)
|
426 |
+
|
427 |
+
# page = np.concatenate([page2, page1], 1)
|
428 |
+
|
429 |
+
page1s_ = np.concatenate(page1s, 0)
|
430 |
+
max_wid = max([i.shape[1] for i in page2s])
|
431 |
+
padded_page2s = []
|
432 |
+
|
433 |
+
for para in page2s:
|
434 |
+
padded_page2s.append(np.concatenate([para, np.ones([para.shape[0], max_wid - para.shape[1]])], 1))
|
435 |
+
|
436 |
+
padded_page2s_ = np.concatenate(padded_page2s, 0)
|
437 |
+
|
438 |
+
return np.concatenate([padded_page2s_, page1s_], 1)
|
439 |
+
|
440 |
+
def get_current_losses(self):
|
441 |
+
|
442 |
+
losses = {}
|
443 |
+
|
444 |
+
losses['G'] = self.loss_G
|
445 |
+
losses['D'] = self.loss_D
|
446 |
+
losses['Dfake'] = self.loss_Dfake
|
447 |
+
losses['Dreal'] = self.loss_Dreal
|
448 |
+
losses['OCR_fake'] = self.loss_OCR_fake
|
449 |
+
losses['OCR_real'] = self.loss_OCR_real
|
450 |
+
losses['w_fake'] = self.loss_w_fake
|
451 |
+
losses['w_real'] = self.loss_w_real
|
452 |
+
losses['cycle'] = self.Lcycle
|
453 |
+
|
454 |
+
return losses
|
455 |
+
|
456 |
+
def _set_input(self, input):
|
457 |
+
self.input = input
|
458 |
+
|
459 |
+
self.real = self.input['img'].to(self.args.device)
|
460 |
+
self.label = self.input['label']
|
461 |
+
|
462 |
+
self.set_ocr_data(self.input['img'], self.input['label'])
|
463 |
+
|
464 |
+
self.sdata = self.input['simg'].to(self.args.device)
|
465 |
+
self.slabels = self.input['slabels']
|
466 |
+
|
467 |
+
self.ST_LEN = self.input['swids']
|
468 |
+
|
469 |
+
def set_requires_grad(self, nets, requires_grad=False):
|
470 |
+
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
|
471 |
+
Parameters:
|
472 |
+
nets (network list) -- a list of networks
|
473 |
+
requires_grad (bool) -- whether the networks require gradients or not
|
474 |
+
"""
|
475 |
+
if not isinstance(nets, list):
|
476 |
+
nets = [nets]
|
477 |
+
for net in nets:
|
478 |
+
if net is not None:
|
479 |
+
for param in net.parameters():
|
480 |
+
param.requires_grad = requires_grad
|
481 |
+
|
482 |
+
def forward(self):
|
483 |
+
self.text_encode, self.len_text, self.encode_pos = self.netconverter.encode(self.label)
|
484 |
+
self.text_encode = self.text_encode.to(self.args.device).detach()
|
485 |
+
self.len_text = self.len_text.detach()
|
486 |
+
|
487 |
+
self.words = [self.word_generator.generate().encode('utf-8') for _ in range(self.args.batch_size)]
|
488 |
+
self.text_encode_fake, self.len_text_fake, self.encode_pos_fake = self.netconverter.encode(self.words)
|
489 |
+
self.text_encode_fake = self.text_encode_fake.to(self.args.device)
|
490 |
+
self.one_hot_fake = make_one_hot(self.text_encode_fake, self.len_text_fake, self.args.vocab_size).to(
|
491 |
+
self.args.device)
|
492 |
+
|
493 |
+
self.fake, self.style = self.netG(self.sdata, self.text_encode_fake)
|
494 |
+
|
495 |
+
self.update_last_fakes(self.fake)
|
496 |
+
|
497 |
+
def pad_width(self, t, new_width):
|
498 |
+
result = torch.ones((t.size(0), t.size(1), t.size(2), new_width), device=t.device)
|
499 |
+
result[:,:,:,:t.size(-1)] = t
|
500 |
+
|
501 |
+
return result
|
502 |
+
|
503 |
+
def compute_real_ocr_loss(self, ocr_network = None):
|
504 |
+
network = ocr_network if ocr_network is not None else self.netOCR
|
505 |
+
real_input = self.ocr_images
|
506 |
+
input_images = real_input
|
507 |
+
input_labels = self.ocr_labels
|
508 |
+
|
509 |
+
input_images = input_images.detach()
|
510 |
+
|
511 |
+
if self.ocr_augmenter is not None:
|
512 |
+
input_images = self.ocr_augmenter(input_images)
|
513 |
+
|
514 |
+
pred_real = network(input_images)
|
515 |
+
preds_size = torch.IntTensor([pred_real.size(0)] * len(input_labels)).detach()
|
516 |
+
text_encode, len_text, _ = self.netconverter.encode(input_labels)
|
517 |
+
|
518 |
+
loss = self.OCR_criterion(pred_real, text_encode.detach(), preds_size, len_text.detach())
|
519 |
+
|
520 |
+
return torch.mean(loss[~torch.isnan(loss)])
|
521 |
+
|
522 |
+
def compute_fake_ocr_loss(self, ocr_network = None):
|
523 |
+
network = ocr_network if ocr_network is not None else self.netOCR
|
524 |
+
|
525 |
+
pred_fake_OCR = network(self.fake)
|
526 |
+
preds_size = torch.IntTensor([pred_fake_OCR.size(0)] * self.args.batch_size).detach()
|
527 |
+
loss_OCR_fake = self.OCR_criterion(pred_fake_OCR, self.text_encode_fake.detach(), preds_size,
|
528 |
+
self.len_text_fake.detach())
|
529 |
+
return torch.mean(loss_OCR_fake[~torch.isnan(loss_OCR_fake)])
|
530 |
+
|
531 |
+
def set_ocr_data(self, images, labels):
|
532 |
+
self.ocr_images = images.to(self.args.device)
|
533 |
+
self.ocr_labels = labels
|
534 |
+
|
535 |
+
def backward_D_OCR(self):
|
536 |
+
self.real.__repr__()
|
537 |
+
self.fake.__repr__()
|
538 |
+
pred_real = self.netD(self.real.detach())
|
539 |
+
pred_fake = self.netD(**{'x': self.fake.detach()})
|
540 |
+
|
541 |
+
self.update_acc(pred_real, pred_fake)
|
542 |
+
|
543 |
+
self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(),
|
544 |
+
self.len_text.detach(), True)
|
545 |
+
|
546 |
+
self.loss_D = self.loss_Dreal + self.loss_Dfake
|
547 |
+
|
548 |
+
if not self.args.no_ocr_loss:
|
549 |
+
self.loss_OCR_real = self.compute_real_ocr_loss()
|
550 |
+
loss_total = self.loss_D + self.loss_OCR_real
|
551 |
+
else:
|
552 |
+
loss_total = self.loss_D
|
553 |
+
|
554 |
+
# backward
|
555 |
+
loss_total.backward()
|
556 |
+
if not self.args.no_ocr_loss:
|
557 |
+
self.clean_grad(self.netOCR.parameters())
|
558 |
+
|
559 |
+
return loss_total
|
560 |
+
|
561 |
+
def clean_grad(self, params):
|
562 |
+
for param in params:
|
563 |
+
param.grad[param.grad != param.grad] = 0
|
564 |
+
param.grad[torch.isnan(param.grad)] = 0
|
565 |
+
param.grad[torch.isinf(param.grad)] = 0
|
566 |
+
|
567 |
+
def backward_D_WL(self):
|
568 |
+
# Real
|
569 |
+
pred_real = self.netD(self.real.detach())
|
570 |
+
|
571 |
+
pred_fake = self.netD(**{'x': self.fake.detach()})
|
572 |
+
|
573 |
+
self.update_acc(pred_real, pred_fake)
|
574 |
+
|
575 |
+
self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(),
|
576 |
+
self.len_text.detach(), True)
|
577 |
+
|
578 |
+
self.loss_D = self.loss_Dreal + self.loss_Dfake
|
579 |
+
|
580 |
+
if not self.args.no_writer_loss:
|
581 |
+
self.loss_w_real = self.netW(self.real.detach(), self.input['wcl'].to(self.args.device)).mean()
|
582 |
+
# total loss
|
583 |
+
loss_total = self.loss_D + self.loss_w_real * self.args.writer_loss_weight
|
584 |
+
else:
|
585 |
+
loss_total = self.loss_D
|
586 |
+
|
587 |
+
# backward
|
588 |
+
loss_total.backward()
|
589 |
+
|
590 |
+
return loss_total
|
591 |
+
|
592 |
+
def optimize_D_WL(self):
|
593 |
+
self.forward()
|
594 |
+
self.set_requires_grad([self.netD], True)
|
595 |
+
self.set_requires_grad([self.netOCR], False)
|
596 |
+
self.set_requires_grad([self.netW], True)
|
597 |
+
self.set_requires_grad([self.netW], True)
|
598 |
+
|
599 |
+
self.optimizer_D.zero_grad()
|
600 |
+
self.optimizer_wl.zero_grad()
|
601 |
+
|
602 |
+
self.backward_D_WL()
|
603 |
+
|
604 |
+
def optimize_D_WL_step(self):
|
605 |
+
self.optimizer_D.step()
|
606 |
+
self.optimizer_wl.step()
|
607 |
+
self.optimizer_D.zero_grad()
|
608 |
+
self.optimizer_wl.zero_grad()
|
609 |
+
|
610 |
+
def compute_cycle_loss(self):
|
611 |
+
fake_input = torch.ones_like(self.sdata)
|
612 |
+
width = min(self.sdata.size(-1), self.fake.size(-1))
|
613 |
+
fake_input[:, :, :, :width] = self.fake.repeat(1, 15, 1, 1)[:, :, :, :width]
|
614 |
+
with torch.no_grad():
|
615 |
+
fake_style = self.netG.compute_style(fake_input)
|
616 |
+
|
617 |
+
return torch.sum(torch.abs(self.style.detach() - fake_style), dim=1).mean()
|
618 |
+
|
619 |
+
def backward_G_only(self):
|
620 |
+
self.gb_alpha = 0.7
|
621 |
+
if self.args.is_cycle:
|
622 |
+
self.Lcycle = self.compute_cycle_loss()
|
623 |
+
|
624 |
+
self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake}), self.len_text_fake.detach(), True).mean()
|
625 |
+
|
626 |
+
compute_ocr = not self.args.no_ocr_loss
|
627 |
+
|
628 |
+
if compute_ocr:
|
629 |
+
self.loss_OCR_fake = self.compute_fake_ocr_loss()
|
630 |
+
|
631 |
+
self.loss_G = self.loss_G + self.Lcycle
|
632 |
+
|
633 |
+
if compute_ocr:
|
634 |
+
self.loss_T = self.loss_G + self.loss_OCR_fake
|
635 |
+
else:
|
636 |
+
self.loss_T = self.loss_G
|
637 |
+
|
638 |
+
if compute_ocr:
|
639 |
+
grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, retain_graph=True)[0]
|
640 |
+
self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
|
641 |
+
|
642 |
+
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, retain_graph=True)[0]
|
643 |
+
self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
|
644 |
+
|
645 |
+
self.loss_T.backward(retain_graph=True)
|
646 |
+
|
647 |
+
if compute_ocr:
|
648 |
+
grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=True, retain_graph=True)[0]
|
649 |
+
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0]
|
650 |
+
a = self.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon + torch.std(grad_fake_OCR))
|
651 |
+
self.loss_OCR_fake = a.detach() * self.loss_OCR_fake
|
652 |
+
self.loss_T = self.loss_G + self.loss_OCR_fake
|
653 |
+
else:
|
654 |
+
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0]
|
655 |
+
a = 1
|
656 |
+
self.loss_T = self.loss_G
|
657 |
+
|
658 |
+
if a is None:
|
659 |
+
print(self.loss_OCR_fake, self.loss_G, torch.std(grad_fake_adv))
|
660 |
+
if a > 1000 or a < 0.0001:
|
661 |
+
print(f'WARNING: alpha > 1000 or alpha < 0.0001 - alpha={a.item()}')
|
662 |
+
|
663 |
+
self.loss_T.backward(retain_graph=True)
|
664 |
+
if compute_ocr:
|
665 |
+
grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=False, retain_graph=True)[0]
|
666 |
+
self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
|
667 |
+
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=False, retain_graph=True)[0]
|
668 |
+
self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
|
669 |
+
|
670 |
+
with torch.no_grad():
|
671 |
+
self.loss_T.backward()
|
672 |
+
if compute_ocr:
|
673 |
+
if any(torch.isnan(torch.unsqueeze(self.loss_OCR_fake, dim=0))) or torch.isnan(self.loss_G):
|
674 |
+
print('loss OCR fake: ', self.loss_OCR_fake, ' loss_G: ', self.loss_G, ' words: ', self.words)
|
675 |
+
sys.exit()
|
676 |
+
|
677 |
+
def backward_G_WL(self):
|
678 |
+
self.gb_alpha = 0.7
|
679 |
+
if self.args.is_cycle:
|
680 |
+
self.Lcycle = self.compute_cycle_loss()
|
681 |
+
|
682 |
+
self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake}), self.len_text_fake.detach(), True).mean()
|
683 |
+
|
684 |
+
if not self.args.no_writer_loss:
|
685 |
+
self.loss_w_fake = self.netW(self.fake, self.input['wcl'].to(self.args.device)).mean()
|
686 |
+
|
687 |
+
self.loss_G = self.loss_G + self.Lcycle
|
688 |
+
|
689 |
+
if not self.args.no_writer_loss:
|
690 |
+
self.loss_T = self.loss_G + self.loss_w_fake * self.args.writer_loss_weight
|
691 |
+
else:
|
692 |
+
self.loss_T = self.loss_G
|
693 |
+
|
694 |
+
self.loss_T.backward(retain_graph=True)
|
695 |
+
|
696 |
+
if not self.args.no_writer_loss:
|
697 |
+
grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=True, retain_graph=True)[0]
|
698 |
+
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0]
|
699 |
+
a = self.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon + torch.std(grad_fake_WL))
|
700 |
+
self.loss_w_fake = a.detach() * self.loss_w_fake
|
701 |
+
self.loss_T = self.loss_G + self.loss_w_fake
|
702 |
+
else:
|
703 |
+
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0]
|
704 |
+
a = 1
|
705 |
+
self.loss_T = self.loss_G
|
706 |
+
|
707 |
+
if a is None:
|
708 |
+
print(self.loss_w_fake, self.loss_G, torch.std(grad_fake_adv))
|
709 |
+
if a > 1000 or a < 0.0001:
|
710 |
+
print(f'WARNING: alpha > 1000 or alpha < 0.0001 - alpha={a.item()}')
|
711 |
+
|
712 |
+
self.loss_T.backward(retain_graph=True)
|
713 |
+
|
714 |
+
if not self.args.no_writer_loss:
|
715 |
+
grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=False, retain_graph=True)[0]
|
716 |
+
self.loss_grad_fake_WL = 10 ** 6 * torch.mean(grad_fake_WL ** 2)
|
717 |
+
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=False, retain_graph=True)[0]
|
718 |
+
self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
|
719 |
+
|
720 |
+
with torch.no_grad():
|
721 |
+
self.loss_T.backward()
|
722 |
+
|
723 |
+
def backward_G(self):
|
724 |
+
self.opt.gb_alpha = 0.7
|
725 |
+
self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake, 'z': self.z}), self.len_text_fake.detach(),
|
726 |
+
self.opt.mask_loss)
|
727 |
+
# OCR loss on real data
|
728 |
+
compute_ocr = not self.args.no_ocr_loss
|
729 |
+
|
730 |
+
if compute_ocr:
|
731 |
+
self.loss_OCR_fake = self.compute_fake_ocr_loss()
|
732 |
+
else:
|
733 |
+
self.loss_OCR_fake = 0.0
|
734 |
+
|
735 |
+
self.loss_w_fake = self.netW(self.fake, self.wcl)
|
736 |
+
# self.loss_OCR_fake = self.loss_OCR_fake + self.loss_w_fake
|
737 |
+
# total loss
|
738 |
+
|
739 |
+
        # l1 = self.params[0]*self.loss_G
        # l2 = self.params[0]*self.loss_OCR_fake
        # l3 = self.params[0]*self.loss_w_fake
        self.loss_G_ = 10 * self.loss_G + self.loss_w_fake
        self.loss_T = self.loss_G_ + self.loss_OCR_fake

        grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, retain_graph=True)[0]

        self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
        grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, retain_graph=True)[0]
        self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)

        if not False:  # gradient-balancing branch (always taken)

            self.loss_T.backward(retain_graph=True)

            grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=True, retain_graph=True)[0]
            grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, create_graph=True, retain_graph=True)[0]
            # grad_fake_wl = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=True, retain_graph=True)[0]

            a = self.opt.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon + torch.std(grad_fake_OCR))

            # a0 = self.opt.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_wl))

            if a is None:
                print(self.loss_OCR_fake, self.loss_G_, torch.std(grad_fake_adv), torch.std(grad_fake_OCR))
            if a > 1000 or a < 0.0001:
                print(f'WARNING: alpha > 1000 or alpha < 0.0001 - alpha={a.item()}')
            b = self.opt.gb_alpha * (torch.mean(grad_fake_adv) -
                                     torch.div(torch.std(grad_fake_adv), self.epsilon + torch.std(grad_fake_OCR)) *
                                     torch.mean(grad_fake_OCR))
            # self.loss_OCR_fake = a.detach() * self.loss_OCR_fake + b.detach() * torch.sum(self.fake)
            self.loss_OCR_fake = a.detach() * self.loss_OCR_fake
            # self.loss_w_fake = a0.detach() * self.loss_w_fake

            self.loss_T = (1 - 1 * self.opt.onlyOCR) * self.loss_G_ + self.loss_OCR_fake  # + self.loss_w_fake
            self.loss_T.backward(retain_graph=True)
            grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=False, retain_graph=True)[0]
            grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, create_graph=False, retain_graph=True)[0]
            self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
            self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
            with torch.no_grad():
                self.loss_T.backward()
        else:
            self.loss_T.backward()

        if self.opt.clip_grad > 0:
            clip_grad_norm_(self.netG.parameters(), self.opt.clip_grad)
        if torch.isnan(self.loss_OCR_fake).any() or torch.isnan(self.loss_G_):
            print('loss OCR fake: ', self.loss_OCR_fake, ' loss_G: ', self.loss_G, ' words: ', self.words)
            sys.exit()

    def optimize_D_OCR(self):
        self.forward()
        self.set_requires_grad([self.netD], True)
        self.set_requires_grad([self.netOCR], True)
        self.optimizer_D.zero_grad()
        # if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']:
        self.optimizer_OCR.zero_grad()
        self.backward_D_OCR()

    def optimize_D_OCR_step(self):
        self.optimizer_D.step()

        self.optimizer_OCR.step()
        self.optimizer_D.zero_grad()
        self.optimizer_OCR.zero_grad()

    def optimize_G_WL(self):
        self.forward()
        self.set_requires_grad([self.netD], False)
        self.set_requires_grad([self.netOCR], False)
        self.set_requires_grad([self.netW], False)
        self.backward_G_WL()

    def optimize_G_only(self):
        self.forward()
        self.set_requires_grad([self.netD], False)
        self.set_requires_grad([self.netOCR], False)
        self.set_requires_grad([self.netW], False)
        self.backward_G_only()

    def optimize_G_step(self):
        self.optimizer_G.step()
        self.optimizer_G.zero_grad()

    def save_networks(self, epoch, save_dir):
        """Save all the networks to the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(save_dir, save_filename)
                net = getattr(self, 'net' + name)

                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    # torch.save(net.module.cpu().state_dict(), save_path)
                    if len(self.gpu_ids) > 1:
                        torch.save(net.module.cpu().state_dict(), save_path)
                    else:
                        torch.save(net.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    def compute_d_scores(self, data_loader: torch.utils.data.DataLoader, amount: int = None):
        scores = []
        words = []
        amount = len(data_loader) if amount is None else amount // data_loader.batch_size

        with torch.no_grad():
            for i in range(amount):
                data = next(iter(data_loader))
                words.extend([d.decode() for d in data['label']])
                scores.extend(list(self.netD(data['img'].to(self.args.device)).squeeze().detach().cpu().numpy()))

        return scores, words

    def compute_d_scores_fake(self, data_loader: torch.utils.data.DataLoader, amount: int = None):
        scores = []
        words = []
        amount = len(data_loader) if amount is None else amount // data_loader.batch_size

        with torch.no_grad():
            for i in range(amount):
                data = next(iter(data_loader))
                to_generate = [self.word_generator.generate().encode('utf-8') for _ in range(data_loader.batch_size)]
                text_encode_fake, len_text_fake, encode_pos_fake = self.netconverter.encode(to_generate)
                fake, _ = self.netG(data['simg'].to(self.args.device), text_encode_fake.to(self.args.device))

                words.extend([d.decode() for d in to_generate])
                scores.extend(list(self.netD(fake).squeeze().detach().cpu().numpy()))

        return scores, words

    def compute_d_stats(self, train_loader: torch.utils.data.DataLoader, val_loader: torch.utils.data.DataLoader):
        train_values = []
        val_values = []
        fake_values = []
        with torch.no_grad():
            for i in range(self.rv_sample_size // train_loader.batch_size):
                data = next(iter(train_loader))
                train_values.append(self.netD(data['img'].to(self.args.device)).squeeze().detach().cpu().numpy())

            for i in range(self.rv_sample_size // val_loader.batch_size):
                data = next(iter(val_loader))
                val_values.append(self.netD(data['img'].to(self.args.device)).squeeze().detach().cpu().numpy())

            for i in range(self.rv_sample_size):
                data = self.last_fakes[i]
                fake_values.append(self.netD(data.unsqueeze(0)).squeeze().detach().cpu().numpy())

        return np.mean(train_values), np.mean(val_values), np.mean(fake_values)
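For context, the `optimize_*` helpers above are meant to be driven in alternating discriminator/OCR and generator phases. The snippet below is a minimal, hypothetical sketch of that call order, assuming a `model` instance of this class and a `train_loader` of batches; the exact phase ordering and data-feeding hook used by the repository's training script may differ.

```python
# Hypothetical training-step sketch using the optimize_* methods defined above.
for step, batch in enumerate(train_loader):
    # ... attach `batch` to the model as done in the repository's train script ...

    # Discriminator + OCR phase: forward(), backward_D_OCR(), then optimizer steps
    model.optimize_D_OCR()
    model.optimize_D_OCR_step()

    # Generator phase: backward passes with D/OCR/W frozen, then the G optimizer step
    model.optimize_G_only()
    model.optimize_G_WL()
    model.optimize_G_step()
```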
models/networks.py
ADDED
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler
from util.util import to_device, load_network

###############################################################################
# Helper Functions
###############################################################################


def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if (isinstance(m, nn.Conv2d)
                or isinstance(m, nn.Linear)
                or isinstance(m, nn.Embedding)):
            # if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'N02':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type in ['glorot', 'xavier']:
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'ortho':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        # if hasattr(m, 'bias') and m.bias is not None:
        #     init.constant_(m.bias.data, 0.0)
        # elif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
        #     init.normal_(m.weight.data, 1.0, init_gain)
        #     init.constant_(m.bias.data, 0.0)
    if init_type in ['N02', 'glorot', 'xavier', 'kaiming', 'ortho']:
        # print('initialize network with %s' % init_type)
        net.apply(init_func)  # apply the initialization function <init_func>
    else:
        # print('loading the model from %s' % init_type)
        net = load_network(net, init_type, 'latest')
    return net


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        gain (float)       -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Return an initialized network.
    """
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net


def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler

    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                              opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine

    For 'linear', we keep the same learning rate for the first <opt.niter> epochs
    and linearly decay the rate to zero over the next <opt.niter_decay> epochs.
    For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
    See https://pytorch.org/docs/stable/optim.html for more details.
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
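These helpers follow the pix2pix/CycleGAN conventions, so a typical call site looks like the sketch below. The `opt` field names (`lr_policy`, `epoch_count`, `niter`, `niter_decay`, `lr_decay_iters`) are assumptions taken from the docstrings above rather than from this repository's option parser, and the stand-in network is purely illustrative.

```python
import argparse
import torch.nn as nn
from torch.optim import Adam

# Assumes the repository root (containing models/ and util/) is on PYTHONPATH.
from models.networks import init_net, get_scheduler

# Hypothetical option namespace; field names follow the docstrings above.
opt = argparse.Namespace(lr_policy='linear', niter=100, niter_decay=100,
                         epoch_count=1, lr_decay_iters=50)

net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())    # stand-in network
net = init_net(net, init_type='N02', init_gain=0.02, gpu_ids=[])  # CPU-only here
optimizer = Adam(net.parameters(), lr=2e-4)
scheduler = get_scheduler(optimizer, opt)

for epoch in range(opt.niter + opt.niter_decay):
    # ... train for one epoch ...
    scheduler.step()  # decay the learning rate according to opt.lr_policy
```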
models/positional_encodings.py
ADDED
@@ -0,0 +1,257 @@
import numpy as np
import torch
import torch.nn as nn


def get_emb(sin_inp):
    """
    Gets a base embedding for one dimension with sin and cos intertwined
    """
    emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
    return torch.flatten(emb, -2, -1)


class PositionalEncoding1D(nn.Module):
    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super(PositionalEncoding1D, self).__init__()
        self.org_channels = channels
        channels = int(np.ceil(channels / 2) * 2)
        self.channels = channels
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.cached_penc = None

    def forward(self, tensor):
        """
        :param tensor: A 3d tensor of size (batch_size, x, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, ch)
        """
        if len(tensor.shape) != 3:
            raise RuntimeError("The input tensor has to be 3d!")

        if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
            return self.cached_penc

        self.cached_penc = None
        batch_size, x, orig_ch = tensor.shape
        pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type())
        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        emb_x = get_emb(sin_inp_x)
        emb = torch.zeros((x, self.channels), device=tensor.device).type(tensor.type())
        emb[:, : self.channels] = emb_x

        self.cached_penc = emb[None, :, :orig_ch].repeat(batch_size, 1, 1)
        return self.cached_penc


class PositionalEncodingPermute1D(nn.Module):
    def __init__(self, channels):
        """
        Accepts (batchsize, ch, x) instead of (batchsize, x, ch)
        """
        super(PositionalEncodingPermute1D, self).__init__()
        self.penc = PositionalEncoding1D(channels)

    def forward(self, tensor):
        tensor = tensor.permute(0, 2, 1)
        enc = self.penc(tensor)
        return enc.permute(0, 2, 1)

    @property
    def org_channels(self):
        return self.penc.org_channels


class PositionalEncoding2D(nn.Module):
    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super(PositionalEncoding2D, self).__init__()
        self.org_channels = channels
        channels = int(np.ceil(channels / 4) * 2)
        self.channels = channels
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq)
        self.cached_penc = None

    def forward(self, tensor):
        """
        :param tensor: A 4d tensor of size (batch_size, x, y, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, y, ch)
        """
        if len(tensor.shape) != 4:
            raise RuntimeError("The input tensor has to be 4d!")

        if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
            return self.cached_penc

        self.cached_penc = None
        batch_size, x, y, orig_ch = tensor.shape
        pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type())
        pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type())
        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
        emb_x = get_emb(sin_inp_x).unsqueeze(1)
        emb_y = get_emb(sin_inp_y)
        emb = torch.zeros((x, y, self.channels * 2), device=tensor.device).type(
            tensor.type()
        )
        emb[:, :, : self.channels] = emb_x
        emb[:, :, self.channels : 2 * self.channels] = emb_y

        self.cached_penc = emb[None, :, :, :orig_ch].repeat(tensor.shape[0], 1, 1, 1)
        return self.cached_penc


class PositionalEncodingPermute2D(nn.Module):
    def __init__(self, channels):
        """
        Accepts (batchsize, ch, x, y) instead of (batchsize, x, y, ch)
        """
        super(PositionalEncodingPermute2D, self).__init__()
        self.penc = PositionalEncoding2D(channels)

    def forward(self, tensor):
        tensor = tensor.permute(0, 2, 3, 1)
        enc = self.penc(tensor)
        return enc.permute(0, 3, 1, 2)

    @property
    def org_channels(self):
        return self.penc.org_channels


class PositionalEncoding3D(nn.Module):
    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super(PositionalEncoding3D, self).__init__()
        self.org_channels = channels
        channels = int(np.ceil(channels / 6) * 2)
        if channels % 2:
            channels += 1
        self.channels = channels
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq)
        self.cached_penc = None

    def forward(self, tensor):
        """
        :param tensor: A 5d tensor of size (batch_size, x, y, z, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, y, z, ch)
        """
        if len(tensor.shape) != 5:
            raise RuntimeError("The input tensor has to be 5d!")

        if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
            return self.cached_penc

        self.cached_penc = None
        batch_size, x, y, z, orig_ch = tensor.shape
        pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type())
        pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type())
        pos_z = torch.arange(z, device=tensor.device).type(self.inv_freq.type())
        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
        sin_inp_z = torch.einsum("i,j->ij", pos_z, self.inv_freq)
        emb_x = get_emb(sin_inp_x).unsqueeze(1).unsqueeze(1)
        emb_y = get_emb(sin_inp_y).unsqueeze(1)
        emb_z = get_emb(sin_inp_z)
        emb = torch.zeros((x, y, z, self.channels * 3), device=tensor.device).type(
            tensor.type()
        )
        emb[:, :, :, : self.channels] = emb_x
        emb[:, :, :, self.channels : 2 * self.channels] = emb_y
        emb[:, :, :, 2 * self.channels :] = emb_z

        self.cached_penc = emb[None, :, :, :, :orig_ch].repeat(batch_size, 1, 1, 1, 1)
        return self.cached_penc


class PositionalEncodingPermute3D(nn.Module):
    def __init__(self, channels):
        """
        Accepts (batchsize, ch, x, y, z) instead of (batchsize, x, y, z, ch)
        """
        super(PositionalEncodingPermute3D, self).__init__()
        self.penc = PositionalEncoding3D(channels)

    def forward(self, tensor):
        tensor = tensor.permute(0, 2, 3, 4, 1)
        enc = self.penc(tensor)
        return enc.permute(0, 4, 1, 2, 3)

    @property
    def org_channels(self):
        return self.penc.org_channels


class Summer(nn.Module):
    def __init__(self, penc):
        """
        :param penc: The type of positional encoding to run the summer on.
        """
        super(Summer, self).__init__()
        self.penc = penc

    def forward(self, tensor):
        """
        :param tensor: A 3, 4 or 5d tensor that matches the model output size
        :return: Positional Encoding Matrix summed to the original tensor
        """
        penc = self.penc(tensor)
        assert (
            tensor.size() == penc.size()
        ), "The original tensor size {} and the positional encoding tensor size {} must match!".format(
            tensor.size(), penc.size()
        )
        return tensor + penc


class SparsePositionalEncoding2D(PositionalEncoding2D):
    def __init__(self, channels, x, y, device='cuda'):
        super(SparsePositionalEncoding2D, self).__init__(channels)
        self.y, self.x = y, x
        self.fake_tensor = torch.zeros((1, x, y, channels), device=device)

    def forward(self, coords):
        """
        :param coords: A list of lists of coordinates (((x1, y1), (x2, y2), ... ), ... )
        :return: Positional Encoding Matrix summed to the original tensor
        """
        encodings = super().forward(self.fake_tensor)
        encodings = encodings.permute(0, 3, 1, 2)
        indices = torch.nn.utils.rnn.pad_sequence([torch.LongTensor(c) for c in coords], batch_first=True, padding_value=-1)
        indices = indices.unsqueeze(0).to(self.fake_tensor.device)
        assert self.x == self.y
        indices = (indices + 0.5) / self.x * 2 - 1
        indices = torch.flip(indices, (-1, ))
        return torch.nn.functional.grid_sample(encodings, indices).squeeze().permute(2, 1, 0)

        # all_encodings = []
        # for coords_row in coords:
        #     res_encodings = []
        #     for xy in coords_row:
        #         if xy is None:
        #             res_encodings.append(padding)
        #         else:
        #             x, y = xy
        #             res_encodings.append(encodings[x, y, :])
        #     all_encodings.append(res_encodings)
        # return torch.stack(res_encodings).to(self.fake_tensor.device)

        # coords = torch.Tensor(coords).to(self.fake_tensor.device).long()
        # assert torch.all(coords[:, 0] < self.x)
        # assert torch.all(coords[:, 1] < self.y)
        # coords = coords[:, 0] + (coords[:, 1] * self.x)
        # encodings = super().forward(self.fake_tensor).reshape((-1, self.org_channels))
        # return encodings[coords]


if __name__ == '__main__':
    pos = SparsePositionalEncoding2D(10, 10, 20)
    pos([[0, 0], [0, 9], [1, 0], [9, 15]])
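The encodings above expect channels-last inputs, with the `Permute` variants handling channels-first feature maps such as CNN outputs. A minimal shape check, assuming `models` is importable as a package from the repository root:

```python
import torch

from models.positional_encodings import (PositionalEncoding2D,
                                          PositionalEncodingPermute2D, Summer)

# channels-last input: (batch, x, y, ch)
feat = torch.rand(2, 8, 16, 64)
pe = PositionalEncoding2D(64)
print(pe(feat).shape)        # torch.Size([2, 8, 16, 64])

# channels-first feature map: (batch, ch, x, y)
fmap = torch.rand(2, 64, 8, 16)
add_pe = Summer(PositionalEncodingPermute2D(64))
print(add_pe(fmap).shape)    # torch.Size([2, 64, 8, 16]), encoding added to the input
```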