import warnings

import torch

# Silence the torch pytree deprecation warning and noisy transformers model-loading logs
# before the transformers classes are imported.
warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated")

import logging

logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

from typing import List, Optional

from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig, AutoModel, AutoModelForCausalLM

from surya.model.recognition.config import MBartMoEConfig, VariableDonutSwinConfig
from surya.model.recognition.encoder import VariableDonutSwinModel
from surya.model.recognition.decoder import MBartMoE
from surya.settings import settings


def load_model(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE, langs: Optional[List[int]] = None):
    """Load the recognition encoder-decoder model from `checkpoint`.

    If `langs` is given, the decoder's language map is first restricted to those
    language token ids.
    """
    config = VisionEncoderDecoderConfig.from_pretrained(checkpoint)

    if langs is not None:
        # Keep only the requested language token ids in the decoder's language map
        config.decoder.langs = {lang_iso: lang_int for lang_iso, lang_int in config.decoder.langs.items() if lang_int in langs}

    # Rebuild the decoder/encoder configs as their specialized classes so that
    # from_pretrained instantiates the MoE decoder and the variable-size Swin encoder.
    decoder_config = vars(config.decoder)
    decoder = MBartMoEConfig(**decoder_config)
    config.decoder = decoder

    encoder_config = vars(config.encoder)
    encoder = VariableDonutSwinConfig(**encoder_config)
    config.encoder = encoder

    # Register the custom config/model pairs so the auto classes can resolve them
    AutoModel.register(MBartMoEConfig, MBartMoE)
    AutoModelForCausalLM.register(MBartMoEConfig, MBartMoE)
    AutoModel.register(VariableDonutSwinConfig, VariableDonutSwinModel)

    model = LangVisionEncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype)
    assert isinstance(model.decoder, MBartMoE)
    assert isinstance(model.encoder, VariableDonutSwinModel)

    model = model.to(device)
    model = model.eval()
    print(f"Loaded recognition model {checkpoint} on device {device} with dtype {dtype}")
    return model


class LangVisionEncoderDecoderModel(VisionEncoderDecoderModel):
    """VisionEncoderDecoderModel that forwards `decoder_langs` to the decoder during generation."""

    def prepare_inputs_for_generation(
        self, input_ids, decoder_langs=None, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, langs=decoder_langs, past_key_values=past_key_values)
        decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
        input_dict = {
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "decoder_input_ids": decoder_inputs["input_ids"],
            "encoder_outputs": encoder_outputs,
            "past_key_values": decoder_inputs["past_key_values"],
            "use_cache": use_cache,
            "decoder_langs": decoder_inputs["langs"],
        }
        return input_dict
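

# Usage sketch (the tensors below are hypothetical placeholders; a recognition
# processor would normally build them from page/line images):
#
#   model = load_model()
#   generated = model.generate(
#       pixel_values=pixel_values,            # preprocessed images
#       decoder_input_ids=decoder_input_ids,  # per-image language tokens + BOS
#       decoder_langs=decoder_langs,          # forwarded to the MoE decoder at each step
#       max_new_tokens=256,
#   )
if __name__ == "__main__":
    # Minimal smoke test; assumes the default checkpoint from surya.settings is available.
    model = load_model()
    print(type(model.encoder).__name__, type(model.decoder).__name__)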