#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Sep  5 10:29:03 2023

@author: peter
"""

import transformers
import torch


class QaracDecoderHead(torch.nn.Module):
    """Decoder head: mixes a prepended summary vector into base-model hidden
    states via two extra Roberta layers, then predicts tokens with an LM head."""

    def __init__(self, config):
        """
        Creates the Decoder head

        Parameters
        ----------
        config : transformers.RobertaConfig
            Config for the RobertaModel that this head will be attached to.

        Returns
        -------
        None.
        """
        super(QaracDecoderHead, self).__init__()
        self.layer_0 = transformers.models.roberta.modeling_roberta.RobertaLayer(config)
        self.layer_1 = transformers.models.roberta.modeling_roberta.RobertaLayer(config)
        self.head = transformers.models.roberta.modeling_roberta.RobertaLMHead(config)

    def forward(self, vector, hidden_states, attention_mask=None):
        """
        Predicts text from vector and hidden states of base model

        Parameters
        ----------
        vector : torch.Tensor
            Vector to be decoded, prepended to the hidden states as an
            extra sequence position. Assumed shape (batch, 1, hidden) —
            the caller unsqueezes before invoking this; TODO confirm.
        hidden_states : torch.Tensor
            Last hidden states of the base model, shape (batch, seq, hidden).
        attention_mask : torch.Tensor, optional
            0/1 mask over hidden_states, shape (batch, seq). The default is None.

        Returns
        -------
        torch.Tensor
            Prediction scores from the LM head.
        """
        # Prepend the vector as one extra "token" along the sequence axis.
        vectors = torch.cat([vector, hidden_states], dim=1)
        # BUG FIX: the original called hidden_states.shape(0) (torch.Size is
        # not callable) and concatenated along dim 0 (batch) instead of dim 1
        # (sequence); the ones column for the prepended vector also now
        # matches the mask's dtype and device.
        if attention_mask is None:
            attentions = None
        else:
            ones = torch.ones((hidden_states.shape[0], 1),
                              dtype=attention_mask.dtype,
                              device=attention_mask.device)
            attentions = torch.cat([ones, attention_mask], dim=1)
        # NOTE(review): RobertaLayer normally receives an *extended* additive
        # mask (via get_extended_attention_mask), not a 2-D 0/1 mask — this
        # mirrors the original call but should be confirmed against the
        # transformers RobertaLayer API.
        l0 = self.layer_0(vectors, attentions)
        # Drop the position occupied by the prepended vector before layer 1.
        return self.head(self.layer_1(l0[0][:, 1:], attention_mask)[0])


class QaracDecoderModel(transformers.RobertaModel,
                        transformers.generation_utils.GenerationMixin):
    # NOTE(review): transformers.generation_utils is a legacy module path that
    # was removed in newer transformers releases; if this breaks on upgrade,
    # use transformers.GenerationMixin instead.

    def __init__(self, model_path, config, tokenizer):
        """
        Creates decoder model from base model

        Parameters
        ----------
        model_path : str
            Path or model id of the pretrained Roberta base model.
        config : transformers.RobertaConfig
            Config used for both the base model and the decoder head.
        tokenizer : transformers.PreTrainedTokenizer
            Tokenizer, stored for use alongside generation.

        Returns
        -------
        None.
        """
        super(QaracDecoderModel, self).__init__(config)
        self.decoder_base = transformers.RobertaModel.from_pretrained(model_path,
                                                                      config=config)
        self.decoder_head = QaracDecoderHead(self.config)
        self.tokenizer = tokenizer

    def forward(self, inputs, **kwargs):
        """
        Predicts text from inputs

        Parameters
        ----------
        inputs : tuple of torch.Tensors OR torch.Tensor OR dict-like
            Vector to be converted to text and seed text, OR tokenized seed
            text alone (when the vector arrives via the keyword argument).
            A dict-like input with 'attention_mask' supplies both the token
            ids ('input_ids') and the mask.
        kwargs : optional keyword arguments
            vector : torch.Tensor
                vector to be decoded. May be supplied via a keyword
                argument when this is invoked by .generate

        Returns
        -------
        torch.Tensor
            Prediction scores from the decoder head.
        """
        (v, s) = (kwargs['vector'], inputs) if 'vector' in kwargs else inputs
        (seed, attention_mask) = ((s['input_ids'], s['attention_mask'])
                                  if 'attention_mask' in s
                                  else (s, None))
        hidden = self.decoder_base(seed,
                                   attention_mask=attention_mask,
                                   use_cache='vector' in kwargs).last_hidden_state
        # BUG FIX: the attention mask was computed but never forwarded to the
        # head, so the head always ran unmasked; pass it through.
        return self.decoder_head(torch.unsqueeze(v, 1), hidden, attention_mask)

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        """GenerationMixin hook: package ids (and a default all-ones mask when
        none is supplied) into the kwargs dict that .generate feeds forward()."""
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        return {'input_ids': input_ids,
                'attention_mask': attention_mask}