MikkoLipsanen committed on
Commit
1838a16
1 Parent(s): 8c07e0b

Upload 3 files

Files changed (3)
  1. augments.py +42 -0
  2. dataset.py +44 -0
  3. train_trocr.py +142 -0
augments.py ADDED
@@ -0,0 +1,42 @@
+ import torch
+ import random
+ import torchvision.transforms as T
+ import numpy as np
+
+ class RandAug:
+     """Apply a randomly chosen set of image augmentations."""
+
+     def __init__(self):
+         # Available augmentation options
+         self.trans = ['rotation', 'blur', 'color', 'sharpness']
+
+     def __call__(self, img):
+         # Randomly choose how many augmentations to apply to the input image
+         n_transforms = random.randint(1, len(self.trans))
+         # Randomly choose the augmentation types
+         transforms = random.sample(self.trans, n_transforms)
+
+         # Apply the chosen augmentations sequentially
+         if 'rotation' in transforms:
+             rotation = random.randint(-10, 10)
+             img = T.functional.rotate(img=img, angle=rotation, expand=True, fill=255)
+
+         if 'blur' in transforms:
+             kernel = random.choice([1, 3, 5])
+             transform = T.GaussianBlur(kernel, sigma=(0.1, 2.0))
+             img = transform(img)
+
+         if 'color' in transforms:
+             rand_brightness = random.uniform(0, 0.3)
+             rand_hue = random.uniform(0, 0.5)
+             rand_contrast = random.uniform(0, 0.5)
+             rand_saturation = random.uniform(0, 0.5)
+             transform = T.ColorJitter(brightness=rand_brightness, contrast=rand_contrast, saturation=rand_saturation, hue=rand_hue)
+             img = transform(img)
+
+         if 'sharpness' in transforms:
+             sharpness = 1 + (np.random.exponential() / 2)
+             transform = T.RandomAdjustSharpness(sharpness, p=1)
+             img = transform(img)
+
+         return img
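
For reference, a minimal usage sketch of RandAug on a single image; line_001.jpg is a hypothetical placeholder for any RGB text line image:

from PIL import Image

from augments import RandAug

# Hypothetical input file; any RGB text line image works.
img = Image.open("line_001.jpg").convert("RGB")

# Applies 1-4 randomly chosen augmentations and returns a PIL image.
augmentator = RandAug()
augmented = augmentator(img)
augmented.save("line_001_aug.jpg")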
dataset.py ADDED
@@ -0,0 +1,44 @@
+ import torch
+ import numpy as np
+ import pandas as pd
+ from PIL import Image
+ from torch.utils.data import Dataset
+
+ from augments import RandAug
+
+ # Torch dataset for text line images and their transcriptions
+ class TextlineDataset(Dataset):
+     def __init__(self, root_dir, df, processor, augment=False, max_target_length=128):
+         self.root_dir = root_dir
+         self.df = df
+         self.processor = processor
+         self.augment = augment
+         self.augmentator = RandAug()
+         self.max_target_length = max_target_length
+
+     def __len__(self):
+         return len(self.df)
+
+     def __getitem__(self, idx):
+         # Get the file name and transcription of the text line
+         file_name = self.df['file_name'][idx]
+         text = self.df['text'][idx]
+
+         # Prepare the image (i.e. resize and normalize)
+         image = Image.open(self.root_dir + file_name).convert("RGB")
+
+         # Apply image augmentations
+         if self.augment:
+             image = self.augmentator(image)
+
+         # Extract the pixel values
+         pixel_values = self.processor(image, return_tensors="pt").pixel_values
+
+         # Add labels (input_ids) by encoding the text
+         labels = self.processor.tokenizer(str(text),
+                                           padding="max_length", truncation=True,
+                                           max_length=self.max_target_length).input_ids
+         # Important: make sure that PAD tokens are ignored by the loss function
+         labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
+         encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+         return encoding
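
For reference, a minimal sketch of building the dataset outside the training script, assuming a hypothetical data.csv with the 'file_name' and 'text' columns that __getitem__ reads:

import pandas as pd
from transformers import TrOCRProcessor

from dataset import TextlineDataset

# Hypothetical CSV with 'file_name' and 'text' columns.
df = pd.read_csv("data.csv")
df.reset_index(drop=True, inplace=True)

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
dataset = TextlineDataset(root_dir="", df=df, processor=processor, augment=True)

encoding = dataset[0]
print(encoding["pixel_values"].shape)  # e.g. torch.Size([3, 384, 384]) for the base model
print(encoding["labels"].shape)        # torch.Size([128]) with the default max_target_length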
train_trocr.py ADDED
@@ -0,0 +1,142 @@
+ import pandas as pd
+ import torch
+ from PIL import Image
+ import argparse
+ from evaluate import load
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator, AdamW
+ import torchvision.transforms as transforms
+ #import torch_optimizer as optim
+
+ from dataset import TextlineDataset
+
+ parser = argparse.ArgumentParser(description='arguments for the code')
+
+ parser.add_argument('--root_path', type=str, default="",
+                     help='Root path to data files.')
+ parser.add_argument('--tr_data_path', type=str, default="/data/htr/trocr_data/trocr_tuomiokirjat/train/trocr/data.csv",
+                     help='Path to the .csv file containing the training data.')
+ parser.add_argument('--val_data_path', type=str, default="/data/htr/trocr_data/trocr_tuomiokirjat/val/trocr/data.csv",
+                     help='Path to the .csv file containing the validation data.')
+ parser.add_argument('--output_path', type=str, default="/koodit/htr/text_recognition/trocr/tuomiokirjat/models/22112023/",
+                     help='Path for saving the training results.')
+ parser.add_argument('--resume_path', type=str, default="/koodit/htr/text_recognition/trocr/tuomiokirjat/models/22112023",
+                     help='Path to the previous model checkpoint.')
+ parser.add_argument('--batch_size', type=int, default=24,
+                     help='Batch size per device.')
+ parser.add_argument('--epochs', type=int, default=13,
+                     help='Number of training epochs.')
+
+ args = parser.parse_args()
+
+ # Run in the background: nohup python train_trocr.py > logs/tuomiokirjat_resume_23112023.txt 2>&1 &
+ # echo $! > logs/save_pid.txt
+
+ # Run using 2 GPUs: torchrun --nproc_per_node=2 train_trocr.py > logs/tuomiokirjat_22112023.txt 2>&1 &
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print('Device: ', device)
+
+ # Initialize the processor and model, either from a pretrained base model or from a local checkpoint
+ #processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
+ #model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
+ processor = TrOCRProcessor.from_pretrained(args.resume_path + "/processor")
+ model = VisionEncoderDecoderModel.from_pretrained(args.resume_path + "/checkpoint-13094")
+
+ model.to(device)
+
+ # Initialize the evaluation metrics
+ cer_metric = load("cer")
+ wer_metric = load("wer")
+
+ # Load the train and validation data into dataframes
+ train_df = pd.read_csv(args.tr_data_path)
+ val_df = pd.read_csv(args.val_data_path)
+ #train_df = train_df.iloc[:50]
+ #val_df = val_df.iloc[:10]
+
+ # Reset the indices to start from zero
+ train_df.reset_index(drop=True, inplace=True)
+ val_df.reset_index(drop=True, inplace=True)
+
+ # Create the train and validation datasets
+ train_dataset = TextlineDataset(root_dir=args.root_path,
+                                 df=train_df,
+                                 processor=processor,
+                                 augment=False)
+
+ eval_dataset = TextlineDataset(root_dir=args.root_path,
+                                df=val_df,
+                                processor=processor,
+                                augment=False)
+
+ print("Number of training examples:", len(train_dataset))
+ print("Number of validation examples:", len(eval_dataset))
+
+ # Define the model configuration
+
+ # Set the special tokens used for creating the decoder_input_ids from the labels
+ model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
+ model.config.pad_token_id = processor.tokenizer.pad_token_id
+ # Make sure the vocab size is set correctly
+ model.config.vocab_size = model.config.decoder.vocab_size
+ # Set the beam search parameters
+ model.config.eos_token_id = processor.tokenizer.sep_token_id
+ model.config.max_length = 64
+ model.config.early_stopping = True
+ model.config.no_repeat_ngram_size = 3
+ model.config.length_penalty = 2.0
+ model.config.num_beams = 4
+
+ # Set the arguments for model training
+ # For all arguments, see https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments
+ training_args = Seq2SeqTrainingArguments(
+     predict_with_generate=True,
+     evaluation_strategy="epoch",
+     save_strategy="epoch",
+     logging_strategy="steps",
+     logging_steps=50,
+     per_device_train_batch_size=args.batch_size,
+     per_device_eval_batch_size=args.batch_size,
+     load_best_model_at_end=True,
+     metric_for_best_model='cer',
+     greater_is_better=False,
+     #fp16=True,
+     num_train_epochs=args.epochs,
+     save_total_limit=2,
+     output_dir=args.output_path,
+     optim="adamw_torch"
+ )
+
+ # Function for computing CER and WER metrics for the prediction results
+ def compute_metrics(pred):
+     labels_ids = pred.label_ids
+     pred_ids = pred.predictions
+
+     pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
+     labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
+     label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
+
+     cer = cer_metric.compute(predictions=pred_str, references=label_str)
+     wer = wer_metric.compute(predictions=pred_str, references=label_str)
+
+     return {"cer": cer, "wer": wer}
+
+
+ # Instantiate the trainer
+ # For all parameters, see https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainer
+ trainer = Seq2SeqTrainer(
+     model=model,
+     tokenizer=processor.feature_extractor,
+     args=training_args,
+     compute_metrics=compute_metrics,
+     train_dataset=train_dataset,
+     eval_dataset=eval_dataset,
+     data_collator=default_data_collator,
+ )
+
+ # Train the model
+ trainer.train()
+ #trainer.train(resume_from_checkpoint = True)
+ model.save_pretrained(args.output_path)
+ processor.save_pretrained(args.output_path + "/processor")