import collections
import dataclasses
import types

import pytorch_lightning as pl
import torch.utils.data
import transformers

from data import (
    generate_annotated_images,
    get_annotation_ground_truth_str,
    DataItem,
    get_extra_tokens,
    Batch,
    Split,
    BatchCollateFunction,
)
from utils import load_pickle_or_build_object_and_save


@dataclasses.dataclass
class Model:
    """Bundle of everything needed to train/run the Donut model.

    Groups the image processor, tokenizer, the vision encoder-decoder
    network, the batch collate function and the run configuration into
    a single object passed around by the rest of the code.
    """

    processor: transformers.models.donut.processing_donut.DonutProcessor
    tokenizer: transformers.models.xlm_roberta.tokenization_xlm_roberta_fast.XLMRobertaTokenizerFast
    encoder_decoder: transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder.VisionEncoderDecoderModel
    batch_collate_function: BatchCollateFunction
    config: types.SimpleNamespace


def add_unknown_tokens_to_tokenizer(
    tokenizer, encoder_decoder, unknown_tokens: list[str]
):
    """Add `unknown_tokens` to the tokenizer vocabulary and resize the
    decoder's token-embedding matrix to match the new vocabulary size."""
    tokenizer.add_tokens(unknown_tokens)
    encoder_decoder.decoder.resize_token_embeddings(len(tokenizer))


def find_unknown_tokens_for_tokenizer(tokenizer) -> collections.Counter:
    """Scan every annotated image's ground-truth string and count the
    tokens the tokenizer maps to its unknown-token id.

    Returns a Counter mapping token string -> occurrence count, so the
    most frequent out-of-vocabulary tokens can be added to the tokenizer.
    """
    unknown_tokens_counter = collections.Counter()

    for annotated_image in generate_annotated_images():
        ground_truth = get_annotation_ground_truth_str(annotated_image.annotation)
        input_ids = tokenizer(ground_truth).input_ids
        tokens = tokenizer.tokenize(ground_truth, add_special_tokens=True)

        # strict=True guarantees ids and token strings stay aligned; a
        # length mismatch would silently mis-attribute counts otherwise.
        for token_id, token in zip(input_ids, tokens, strict=True):
            if token_id == tokenizer.unk_token_id:
                unknown_tokens_counter.update([token])

    return unknown_tokens_counter


def replace_pad_token_id_with_negative_hundred_for_hf_transformers_automatic_batch_transformation(
    tokenizer, token_ids
):
    """Replace pad-token positions with -100 (in place) so the HF loss
    function ignores them; returns the same tensor for convenience."""
    token_ids[token_ids == tokenizer.pad_token_id] = -100
    return token_ids


# NOTE(review): this dataclass shadows the `BatchCollateFunction` name
# imported from `data` above (which is used as the type annotation on
# `Model.batch_collate_function`). Confirm the import is meant only as a
# type alias / re-export before cleaning this up.
@dataclasses.dataclass
class BatchCollateFunction:
    """Collate a list of DataItems into a model-ready Batch.

    Images go through the Donut processor (random padding during
    training only); target strings are tokenized to fixed-length label
    tensors with pad positions masked to -100.
    """

    processor: transformers.models.donut.processing_donut.DonutProcessor
    tokenizer: transformers.models.xlm_roberta.tokenization_xlm_roberta_fast.XLMRobertaTokenizerFast
    decoder_sequence_max_length: int

    def __call__(self, batch: list[DataItem], split: Split) -> Batch:
        images = [di.image for di in batch]
        # Random padding is a light augmentation applied only at train time.
        images = self.processor(
            images, random_padding=split == Split.train, return_tensors="pt"
        ).pixel_values

        target_token_ids = self.tokenizer(
            [di.target_string for di in batch],
            add_special_tokens=False,
            max_length=self.decoder_sequence_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids
        labels = replace_pad_token_id_with_negative_hundred_for_hf_transformers_automatic_batch_transformation(
            self.tokenizer, target_token_ids
        )

        data_indices = [di.data_index for di in batch]
        return Batch(images=images, labels=labels, data_indices=data_indices)


def build_model(config: types.SimpleNamespace) -> Model:
    """Construct the full Model bundle from a pretrained Donut checkpoint.

    Resizes the processor/encoder to the configured image size, extends
    the tokenizer with task-specific and dataset-specific unknown tokens,
    and wires the special token ids used for generation.

    Note: the original annotation was ``types.SimpleNamespace or object``,
    which evaluates eagerly to ``types.SimpleNamespace`` — the annotation
    is now written as the value it always was.
    """
    donut_processor = transformers.DonutProcessor.from_pretrained(
        config.pretrained_model_name
    )
    donut_processor.image_processor.size = dict(
        width=config.image_width, height=config.image_height
    )
    donut_processor.image_processor.do_align_long_axis = False
    tokenizer = donut_processor.tokenizer

    encoder_decoder_config = transformers.VisionEncoderDecoderConfig.from_pretrained(
        config.pretrained_model_name
    )
    # NOTE(review): HF Donut configs conventionally store image_size as
    # (height, width); this assigns (width, height) — confirm ordering.
    encoder_decoder_config.encoder.image_size = (
        config.image_width,
        config.image_height,
    )
    encoder_decoder = transformers.VisionEncoderDecoderModel.from_pretrained(
        config.pretrained_model_name, config=encoder_decoder_config
    )

    # The model holds a reference to this same config object, so mutating
    # it after construction updates the model's generation settings too.
    encoder_decoder_config.pad_token_id = tokenizer.pad_token_id
    encoder_decoder_config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(
        get_extra_tokens().benetech_prompt
    )
    encoder_decoder_config.bos_token_id = encoder_decoder_config.decoder_start_token_id
    encoder_decoder_config.eos_token_id = tokenizer.convert_tokens_to_ids(
        get_extra_tokens().benetech_prompt_end
    )

    # First add the fixed task tokens, then the dataset-derived unknown
    # tokens (cached on disk so the dataset scan runs only once).
    extra_tokens = list(get_extra_tokens().__dict__.values())
    add_unknown_tokens_to_tokenizer(tokenizer, encoder_decoder, extra_tokens)
    unknown_dataset_tokens = load_pickle_or_build_object_and_save(
        config.unknown_tokens_for_tokenizer_path,
        lambda: list(find_unknown_tokens_for_tokenizer(tokenizer).keys()),
    )
    add_unknown_tokens_to_tokenizer(tokenizer, encoder_decoder, unknown_dataset_tokens)
    tokenizer.eos_token_id = encoder_decoder_config.eos_token_id

    batch_collate_function = BatchCollateFunction(
        processor=donut_processor,
        tokenizer=tokenizer,
        decoder_sequence_max_length=config.decoder_sequence_max_length,
    )

    return Model(
        processor=donut_processor,
        tokenizer=tokenizer,
        encoder_decoder=encoder_decoder,
        batch_collate_function=batch_collate_function,
        config=config,
    )


def generate_token_strings(
    model: Model, images: torch.Tensor, skip_special_tokens=True
) -> list[str]:
    """Autoregressively generate and decode one string per input image.

    In debug mode generation is capped at 10 tokens to keep runs fast.
    """
    decoder_output = model.encoder_decoder.generate(
        images,
        max_length=10
        if model.config.debug
        else model.config.decoder_sequence_max_length,
        eos_token_id=model.tokenizer.eos_token_id,
        return_dict_in_generate=True,
    )
    return model.tokenizer.batch_decode(
        decoder_output.sequences, skip_special_tokens=skip_special_tokens
    )


def predict_string(image, model: Model):
    """Preprocess a single raw image and return its generated string."""
    image = model.processor(
        image, random_padding=False, return_tensors="pt"
    ).pixel_values
    string = generate_token_strings(model, image)[0]
    return string


class LightningModule(pl.LightningModule):
    """Thin Lightning wrapper around the encoder-decoder for train/val."""

    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()
        self.model = build_model(config)
        self.encoder_decoder = self.model.encoder_decoder

    def training_step(self, batch: Batch, batch_idx: int) -> torch.Tensor:
        loss = self.compute_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch: Batch, batch_idx: int, dataset_idx: int = 0):
        loss = self.compute_loss(batch)
        self.log("val_loss", loss)

    def compute_loss(self, batch: Batch) -> torch.Tensor:
        # HF computes the cross-entropy internally when labels are given;
        # positions set to -100 by the collate function are ignored.
        outputs = self.encoder_decoder(pixel_values=batch.images, labels=batch.labels)
        loss = outputs.loss
        return loss

    def configure_optimizers(self) -> torch.optim.Optimizer:
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.hparams["config"].learning_rate
        )
        return optimizer