foo5 / modeling_custom5.py
denizyuret-shallowai's picture
Upload model
0e6dbed
# https://huggingface.co/docs/transformers/custom_models
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoTokenizer, AutoModel, AutoConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.nn import CrossEntropyLoss
from torch.nn.functional import log_softmax
from torch.nn.modules.container import ModuleList
from .configuration_custom5 import CustomConfig5
class CustomModel5(PreTrainedModel):
    """Causal-LM ensemble that mixes the log-probabilities of several sub-models.

    Every checkpoint named in ``config.models`` is loaded with
    ``AutoModelForCausalLM.from_pretrained``. ``forward`` returns, as
    ``logits``, the sum over sub-models of ``coeff * log_softmax(sub_logits)``
    with the weights taken from ``config.coeffs``.
    """

    config_class = CustomConfig5

    def __init__(self, config):
        """Load each checkpoint listed in ``config.models`` as a sub-model."""
        super().__init__(config)
        self.model = ModuleList(
            [AutoModelForCausalLM.from_pretrained(m) for m in config.models]
        )

    def forward(self, *args, labels=None, **kwargs):
        """Run every sub-model and combine their weighted log-probabilities.

        Args:
            *args: Positional arguments forwarded unchanged to each sub-model.
            labels: Optional target token ids. When given, a next-token
                cross-entropy loss is computed on the combined scores. They
                are NOT forwarded to the sub-models.
            **kwargs: Keyword arguments forwarded unchanged to each sub-model.

        Returns:
            CausalLMOutputWithPast with ``logits`` set to the weighted sum of
            the sub-models' log-softmax outputs and ``loss`` set to the
            cross-entropy loss (``None`` when ``labels`` is ``None``).

        Raises:
            ValueError: If there are no sub-models, or if the number of
                coefficients differs from the number of sub-models.
        """
        coeffs = self.config.coeffs
        if len(self.model) == 0:
            raise ValueError("CustomModel5 needs at least one sub-model in config.models")
        if len(coeffs) != len(self.model):
            # zip() would otherwise silently drop the unmatched entries.
            raise ValueError(
                f"config.coeffs has {len(coeffs)} entries but "
                f"{len(self.model)} sub-models were loaded"
            )

        loss = None
        logits = None
        for model, coeff in zip(self.model, coeffs):
            # Call the module (not .forward) so that registered hooks run.
            logp = log_softmax(model(*args, **kwargs).logits, dim=-1)
            logits = coeff * logp if logits is None else logits + coeff * logp

        # The rest copied from modeling_llama.py:
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens; take the class dimension from the tensor
            # itself instead of relying on config.vocab_size being set and
            # matching the sub-models.
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
        return CausalLMOutputWithPast(loss=loss, logits=logits)
## Which one do we use?
## You have to tell the library that you want to copy the code files of those objects when using the save_pretrained method, and properly register them with a given Auto class (especially for models). To do so, just run:
# CustomConfig5.register_for_auto_class()
# CustomModel5.register_for_auto_class('AutoModelForCausalLM')
# CustomModel5.register_for_auto_class('AutoModel')
## If you are writing a library that extends 🤗 Transformers, you may want to extend the auto classes to include your own model. This is different from pushing the code to the Hub in the sense that users will need to import your library to get the custom models (as opposed to having the model code downloaded automatically from the Hub).
## As long as your config has a model_type attribute that differs from the existing model types, and your model classes have the right config_class attributes, you can just add them to the auto classes like this:
# AutoConfig.register("custom5", CustomConfig5)
# AutoModel.register(CustomConfig5, CustomModel5)
# AutoModelForCausalLM.register(CustomConfig5, CustomModel5)