# ketanmore's picture
# Upload folder using huggingface_hub
# 2720487 verified
import warnings
import torch
warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated")
import logging
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
from typing import List, Optional, Tuple
from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig, AutoModel, AutoModelForCausalLM
from surya.model.recognition.config import MBartMoEConfig, VariableDonutSwinConfig
from surya.model.recognition.encoder import VariableDonutSwinModel
from surya.model.recognition.decoder import MBartMoE
from surya.settings import settings
def load_model(
    checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT,
    device=settings.TORCH_DEVICE_MODEL,
    dtype=settings.MODEL_DTYPE,
    langs: Optional[List[int]] = None,
):
    """Load the recognition encoder-decoder model, optionally restricted to some languages.

    Args:
        checkpoint: Passed to ``from_pretrained`` for both the config and the weights.
        device: Device the model is moved to after loading.
        dtype: Passed as ``torch_dtype`` when loading the weights.
        langs: Optional collection of integer language ids. When given, every
            entry of ``config.decoder.langs`` whose id is not in ``langs`` is
            dropped before the model is built, so experts that are not needed
            are pruned before loading.

    Returns:
        The loaded ``LangVisionEncoderDecoderModel``, moved to ``device`` and
        switched to eval mode.

    Raises:
        TypeError: If the loaded decoder is not an ``MBartMoE`` or the loaded
            encoder is not a ``VariableDonutSwinModel``.
    """
    config = VisionEncoderDecoderConfig.from_pretrained(checkpoint)

    # Prune moe experts that are not needed before loading the model
    if langs is not None:
        # Build the lookup set once. Each membership test is then O(1), and a
        # one-shot iterable (e.g. a generator) is not consumed by the first test.
        wanted_langs = set(langs)
        config.decoder.langs = {
            lang_iso: lang_int
            for lang_iso, lang_int in config.decoder.langs.items()
            if lang_int in wanted_langs
        }

    # Rebuild the sub-configs as the project's custom config classes from their attribute dicts
    config.decoder = MBartMoEConfig(**vars(config.decoder))
    config.encoder = VariableDonutSwinConfig(**vars(config.encoder))

    # Get transformers to load custom encoder/decoder
    AutoModel.register(MBartMoEConfig, MBartMoE)
    AutoModelForCausalLM.register(MBartMoEConfig, MBartMoE)
    AutoModel.register(VariableDonutSwinConfig, VariableDonutSwinModel)

    model = LangVisionEncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype)

    # Explicit checks instead of ``assert`` so they still run under ``python -O``.
    if not isinstance(model.decoder, MBartMoE):
        raise TypeError(
            f"Expected decoder of type MBartMoE, got {type(model.decoder).__name__}"
        )
    if not isinstance(model.encoder, VariableDonutSwinModel):
        raise TypeError(
            f"Expected encoder of type VariableDonutSwinModel, got {type(model.encoder).__name__}"
        )

    model = model.to(device)
    model = model.eval()
    print(f"Loaded recognition model {checkpoint} on device {device} with dtype {dtype}")
    return model
class LangVisionEncoderDecoderModel(VisionEncoderDecoderModel):
    """``VisionEncoderDecoderModel`` that carries decoder language ids through generation.

    Only ``prepare_inputs_for_generation`` is overridden. It forwards
    ``decoder_langs`` to the decoder's own ``prepare_inputs_for_generation``
    (as ``langs``) and returns the decoder's ``langs`` under the
    ``decoder_langs`` key.
    """

    def prepare_inputs_for_generation(
        self,
        input_ids,
        decoder_langs=None,
        past_key_values=None,
        attention_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        """Assemble the keyword inputs for one generation step.

        Args:
            input_ids: Token ids handed in by the caller; forwarded to the decoder.
            decoder_langs: Language ids forwarded to the decoder as ``langs``.
            past_key_values: Forwarded to the decoder's input preparation.
            attention_mask: Returned unchanged under the ``attention_mask`` key.
            use_cache: Returned unchanged under the ``use_cache`` key.
            encoder_outputs: Returned unchanged under the ``encoder_outputs`` key.
            **kwargs: Accepted for signature compatibility; not used.

        Returns:
            A dict of step inputs. ``decoder_attention_mask`` is ``None`` when
            the decoder's prepared inputs contain no ``attention_mask``.
        """
        step_inputs = self.decoder.prepare_inputs_for_generation(
            input_ids,
            langs=decoder_langs,
            past_key_values=past_key_values,
        )

        # The decoder may or may not produce its own attention mask.
        if "attention_mask" in step_inputs:
            dec_mask = step_inputs["attention_mask"]
        else:
            dec_mask = None

        return {
            "attention_mask": attention_mask,
            "decoder_attention_mask": dec_mask,
            "decoder_input_ids": step_inputs["input_ids"],
            "encoder_outputs": encoder_outputs,
            "past_key_values": step_inputs["past_key_values"],
            "use_cache": use_cache,
            "decoder_langs": step_inputs["langs"],
        }