File size: 7,537 Bytes
9c909e3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 |
from dataclasses import dataclass, field
from typing import Optional
import pandas as pd
import os
import torch
from transformers import VisionEncoderDecoderModel, TrOCRProcessor, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator, EarlyStoppingCallback
from peft import LoraConfig, get_peft_model
from data import AphaPenDataset
import evaluate
from sklearn.model_selection import train_test_split
from src.calibrator import EncoderDecoderCalibrator
from src.loss import MarginLoss, KLRegularization
from src.similarity import CERSimilarity
from datetime import datetime
import torch.nn.functional as F
os.environ["WANDB_PROJECT"] = "Alphapen-TrOCR"

# Step 1: Load the dataset
train_df_path = "/mnt/data1/Datasets/AlphaPen/" + "training_data.csv"
test_df_path = "/mnt/data1/Datasets/AlphaPen/" + "testing_data.csv"
# NOTE(review): train_df_path is not read. Both splits are carved out of
# testing_data.csv: positional rows [0, 4000) for training, the rest for
# evaluation. The original training-file loading was:
#   train_df = pd.read_csv(train_df_path); train_df.dropna(inplace=True)
# The CSV is parsed once and sliced twice. Each split drops rows with missing
# values and is then renumbered from zero.
_split_source = pd.read_csv(test_df_path)
train_df = _split_source[:4000].dropna().reset_index(drop=True)
test_df = _split_source[4000:].dropna().reset_index(drop=True)
del _split_source

root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/final_cropped_rotated_"
model_name = "microsoft/trocr-large-handwritten"
processor = TrOCRProcessor.from_pretrained(model_name)
train_dataset = AphaPenDataset(root_dir=root_dir, df=train_df, processor=processor)
eval_dataset = AphaPenDataset(root_dir=root_dir, df=test_df, processor=processor)
# Step 2: Load the model
model = VisionEncoderDecoderModel.from_pretrained(model_name)
# Token ids used to build decoder_input_ids from the labels, plus a vocab
# size kept in sync with the decoder's. Keys are applied in this order.
model.config.update({
    "decoder_start_token_id": processor.tokenizer.cls_token_id,
    "pad_token_id": processor.tokenizer.pad_token_id,
    "vocab_size": model.config.decoder.vocab_size,
})
# The same vocab size is mirrored on the model object (original note: "for peft").
model.vocab_size = model.config.decoder.vocab_size
# Beam-search generation settings.
# NOTE(review): newer transformers versions prefer these on
# model.generation_config rather than model.config; confirm for the
# installed version.
model.config.update({
    "eos_token_id": processor.tokenizer.sep_token_id,
    "max_length": 64,
    "early_stopping": True,
    "no_repeat_ngram_size": 3,
    "length_penalty": 2.0,
    "num_beams": 4,
})
# LoRa: rank-1 adapters (alpha 8, dropout 0.1) on modules matching these names.
# NOTE(review): the patterns (query/key/value, intermediate.dense, output.dense)
# look like ViT-encoder layer names; confirm with
# model.print_trainable_parameters() which layers actually receive adapters.
# Commented-out alternatives from the original config:
#   'wte', 'wpe', 'c_attn', 'c_proj', 'q_attn', 'c_fc'
lora_config = LoraConfig(
    target_modules=['query', 'key', 'value', 'intermediate.dense', 'output.dense'],
    r=1,
    lora_alpha=8,
    lora_dropout=0.1,
)
model = get_peft_model(model, lora_config)
tokenizer = processor.tokenizer
# sim = CERSimilarity(tokenizer)
# loss = MarginLoss(sim, beta=0.1, num_samples=60)
# reg = KLRegularization(model)
# calibrator = EncoderDecoderCalibrator(model, loss, reg, 15, 15)

# Step 3: Define the training arguments
# NOTE(review): save_steps (20000) is larger than max_steps (10000), so no
# step-based checkpoint looks reachable in this run. Confirm that is intended.
# NOTE(review): recent transformers releases renamed `evaluation_strategy` to
# `eval_strategy`; use whichever the installed version accepts.
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    bf16=True,
    bf16_full_eval=True,
    output_dir="./",
    logging_steps=100,
    save_steps=20000,
    eval_steps=500,
    # report_to="wandb",
    optim="adamw_torch_fused",
    lr_scheduler_type="cosine",
    gradient_accumulation_steps=2,
    learning_rate=1.0e-4,
    max_steps=10000,
    # %S is zero-padded seconds. The previous lowercase %s is a non-portable,
    # platform-specific extension (seconds since the epoch on glibc) that is
    # not supported everywhere.
    run_name=f"trocr-LoRA-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}",
)
# Step 4: Define a metric
cer_metric = evaluate.load("cer")


def compute_cer(pred, target):
    """Return the character error rate (CER) of one prediction against one reference.

    Args:
        pred: predicted text.
        target: reference text.

    Returns:
        The CER as a float. 0.0 means an exact match; the value can exceed
        1.0 when there are many insertions.
    """
    score = cer_metric.compute(predictions=[pred], references=[target])
    # The `evaluate` CER metric returns a bare float, so indexing it with
    # ['cer'] raised TypeError. A dict result is still accepted in case a
    # different metric implementation is swapped in.
    return score["cer"] if isinstance(score, dict) else score
def generate_candidates(model, pixel_values, num_candidates=10):
    """Run beam search on `pixel_values`, returning `num_candidates` hypotheses per input.

    The beam width equals the number of returned sequences. Scores are
    requested and the result is whatever `model.generate` returns in
    dict-style (`return_dict_in_generate=True`) form.
    """
    generation_kwargs = dict(
        num_return_sequences=num_candidates,
        num_beams=num_candidates,
        output_scores=True,
        return_dict_in_generate=True,
    )
    return model.generate(pixel_values, **generation_kwargs)
def rank_loss(positive_scores, negative_scores):
    """Mean hinge loss asking each positive score to exceed its paired negative by 1."""
    gap = 1 - positive_scores + negative_scores
    return F.relu(gap).mean()
def margin_loss(positive_scores, negative_scores, margin=0.1):
    """Mean hinge loss asking each positive score to exceed its paired negative by `margin`."""
    violation = margin - positive_scores + negative_scores
    return F.relu(violation).mean()
def calibration_loss(model, pixel_values, ground_truth, processor, loss_type='margin',
                     num_candidates=10):
    """Pairwise ranking loss that calibrates model likelihoods against CER.

    For each image, `num_candidates` beam-search hypotheses are generated and
    ranked by similarity to that image's reference text (1 - CER). For every
    pair of hypotheses with different similarity, the model's mean per-token
    log-likelihood of the more similar hypothesis is pushed above that of the
    less similar one, using `margin_loss` or `rank_loss`.

    Args:
        model: encoder-decoder model exposing `generate` and a forward call
            that accepts `pixel_values` and `decoder_input_ids` and returns
            `.logits`.
        pixel_values: image batch; dim 0 is the batch size.
        ground_truth: label ids, either (batch, seq) or a single (seq,) row.
            Entries equal to -100 are treated as padding.
        processor: object providing `batch_decode` and `tokenizer.pad_token_id`.
        loss_type: 'margin' or 'rank'.
        num_candidates: beam width and number of hypotheses per image.

    Returns:
        A scalar tensor. It is 0.0 (without gradient) when no candidate pair
        has distinct similarity.

    Raises:
        ValueError: if `loss_type` is not 'rank' or 'margin'.
    """
    if loss_type == 'rank':
        pair_loss = rank_loss
    elif loss_type == 'margin':
        pair_loss = margin_loss
    else:
        raise ValueError("Invalid loss type. Choose 'rank' or 'margin'.")

    pad_id = processor.tokenizer.pad_token_id

    # -100 marks ignored label positions and is not a valid token id. Map it
    # to the pad token, which skip_special_tokens drops, before decoding one
    # reference per image.
    labels = ground_truth if ground_truth.dim() > 1 else ground_truth.unsqueeze(0)
    labels = labels.masked_fill(labels == -100, pad_id)
    references = processor.batch_decode(labels, skip_special_tokens=True)

    # Candidate selection only. transformers' generate() runs without
    # gradients, so its sequences_scores cannot train the model.
    candidates = generate_candidates(model, pixel_values, num_candidates=num_candidates)
    sequences = candidates.sequences  # batch * num_candidates rows, grouped per image
    hypotheses = processor.batch_decode(sequences, skip_special_tokens=True)

    # Build (better, worse) row-index pairs within each image's candidate group.
    better_idx, worse_idx = [], []
    for b, reference in enumerate(references):
        if not reference:
            # CER is undefined for an empty reference; skip this image.
            continue
        offset = b * num_candidates
        sims = [1 - compute_cer(hyp, reference)
                for hyp in hypotheses[offset:offset + num_candidates]]
        for i, sim_i in enumerate(sims):
            for j in range(i + 1, len(sims)):
                if sim_i > sims[j]:
                    better_idx.append(offset + i)
                    worse_idx.append(offset + j)
                elif sims[j] > sim_i:
                    better_idx.append(offset + j)
                    worse_idx.append(offset + i)
                # Ties carry no ranking signal and are skipped.
    if not better_idx:
        return torch.tensor(0.0, device=pixel_values.device)

    # Differentiable re-scoring is a teacher-forced pass over every candidate.
    # NOTE(review): this runs batch * num_candidates sequences through the full
    # model with gradients and may be memory-heavy.
    decoder_input_ids = sequences[:, :-1]
    target_ids = sequences[:, 1:]
    outputs = model(
        pixel_values=pixel_values.repeat_interleave(num_candidates, dim=0),
        decoder_input_ids=decoder_input_ids,
    )
    # Upcast before log_softmax for numerical stability under bf16.
    token_logprobs = outputs.logits.float().log_softmax(dim=-1)
    token_logprobs = token_logprobs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
    # Padding after the end of a hypothesis does not count towards its score.
    token_mask = (target_ids != pad_id).to(token_logprobs.dtype)
    scores = (token_logprobs * token_mask).sum(dim=-1) / token_mask.sum(dim=-1).clamp(min=1.0)

    better = scores[torch.tensor(better_idx, device=scores.device)]
    worse = scores[torch.tensor(worse_idx, device=scores.device)]
    return pair_loss(better, worse)
class CalibratedTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer whose training loss adds a calibration term.

    The loss is ``ce_loss + calibration_weight * calibration_loss(...)``.

    Extra keyword arguments are removed from ``kwargs`` before the rest are
    forwarded to ``Seq2SeqTrainer``:
        processor: processor handed to ``calibration_loss`` for decoding.
        calibration_weight: multiplier for the calibration term (default 0.1).
            0 disables the term and skips candidate generation.
        calibration_loss_type: 'margin' (default) or 'rank'.

    Raises:
        ValueError: if ``calibration_weight`` is non-zero but no ``processor``
            was given.
    """

    def __init__(self, *args, **kwargs):
        self.processor = kwargs.pop('processor', None)
        self.calibration_weight = kwargs.pop('calibration_weight', 0.1)
        self.calibration_loss_type = kwargs.pop('calibration_loss_type', 'margin')
        if self.calibration_weight and self.processor is None:
            raise ValueError(
                "CalibratedTrainer requires `processor` when calibration_weight is non-zero."
            )
        super().__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the cross-entropy plus weighted calibration loss for one batch.

        Args:
            model: the model being trained.
            inputs: batch mapping containing ``pixel_values`` and ``labels``.
                Label positions equal to -100 are ignored by the loss.
            return_outputs: if True, also return the model's forward outputs.
            num_items_in_batch: accepted for compatibility with Trainer versions
                that pass it; not used here.

        Returns:
            ``total_loss``, or ``(total_loss, outputs)`` when ``return_outputs``.
        """
        # Shallow copy so popping "labels" does not mutate the caller's batch.
        inputs = dict(inputs)
        labels = inputs.pop("labels")
        pixel_values = inputs["pixel_values"]

        # Teacher-forced forward pass. The previous generate() call produced no
        # gradients and returned a tuple of per-step logits that could not be
        # reshaped against the labels. Passing labels lets the model build
        # decoder_input_ids and compute token cross-entropy (ignoring -100).
        outputs = model(**inputs, labels=labels)
        ce_loss = outputs.loss

        if not self.calibration_weight:
            total_loss = ce_loss
        else:
            cal_loss = calibration_loss(
                model, pixel_values, labels, self.processor, self.calibration_loss_type
            )
            total_loss = ce_loss + self.calibration_weight * cal_loss
        return (total_loss, outputs) if return_outputs else total_loss
def compute_metrics(pred):
    """Compute the character error rate (CER) for a Seq2SeqTrainer evaluation.

    Args:
        pred: evaluation-prediction object whose ``predictions`` hold generated
            token ids (``predict_with_generate=True``) and whose ``label_ids``
            hold reference ids. Both are assumed to be numpy arrays, as the HF
            Trainer supplies them; confirm if called elsewhere.

    Returns:
        ``{"cer": <float>}``.
    """
    pred_ids = pred.predictions
    # Some evaluation loops return a tuple whose first element is the ids.
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]
    pad_id = processor.tokenizer.pad_token_id
    # -100 marks ignored label positions. The Trainer may also use it to pad
    # predictions of different lengths. It is not a valid token id, so map it
    # to the pad token, which skip_special_tokens drops. Copies keep the
    # caller's arrays unmodified.
    pred_ids = pred_ids.copy()
    pred_ids[pred_ids == -100] = pad_id
    labels_ids = pred.label_ids.copy()
    labels_ids[labels_ids == -100] = pad_id
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
    # Reuse the module-level metric instead of reloading it on every evaluation.
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}
# Step 5: Define the Trainer
trainer = CalibratedTrainer(
    model=model,
    # NOTE(review): the image feature extractor is passed as `tokenizer`;
    # newer transformers versions name this argument `processing_class`.
    tokenizer=processor.feature_extractor,
    args=training_args,
    # compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
    processor=processor,
    calibration_weight=0.1,
    calibration_loss_type='margin',  # or 'rank'
)
trainer.train()