import types
import torch
import torch.nn.functional as F
import numpy as np
from torch import nn
from transformers import (T5ForConditionalGeneration, T5EncoderModel, AutoModel, T5Tokenizer,
                          LogitsProcessor, LogitsProcessorList, PreTrainedModel)
from functools import partial
from undecorate import unwrap
from types import MethodType
from utils import *
from ling_disc import DebertaReplacedTokenizer
from const import *


def vae_sample(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return eps * std + mu


class VAE(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(args.input_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, args.hidden_dim),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(args.latent_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, args.input_dim),
        )
        self.fc_mu = nn.Linear(args.hidden_dim, args.latent_dim)
        self.fc_var = nn.Linear(args.hidden_dim, args.latent_dim)

    def forward(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_var(h)
        x = vae_sample(mu, logvar)
        o = self.decoder(x)
        return o, (mu, logvar)


class LingGenerator(nn.Module):
    def __init__(self, args, hidden_dim=1000):
        super().__init__()
        self.gen = T5EncoderModel.from_pretrained('google/flan-t5-small')
        self.hidden_size = self.gen.config.d_model
        self.ling_embed = nn.Linear(args.lng_dim, self.hidden_size)
        # self.gen = nn.Sequential(
        #     nn.Linear(args.lng_dim, 2*hidden_dim),
        #     nn.ReLU(),
        #     nn.BatchNorm1d(2*hidden_dim),
        #     nn.Linear(2*hidden_dim, 2*hidden_dim),
        #     nn.ReLU(),
        #     nn.BatchNorm1d(2*hidden_dim),
        #     nn.Linear(2*hidden_dim, hidden_dim),
        #     nn.ReLU(),
        # )
        self.gen_type = args.linggen_type
        self.gen_input = args.linggen_input
        if self.gen_type == 'vae':
            self.gen_mu = nn.Linear(hidden_dim, args.lng_dim)
            self.gen_logvar = nn.Linear(hidden_dim, args.lng_dim)
        elif self.gen_type == 'det':
            self.projection = nn.Linear(self.hidden_size, args.lng_dim)

    def forward(self, batch):
        inputs_embeds = self.gen.shared(batch['sentence1_input_ids'])
        inputs_att_mask = batch['sentence1_attention_mask']
        bs = inputs_embeds.shape[0]
        if self.gen_input == 's+l':
            sent1_ling = self.ling_embed(batch['sentence1_ling'])
            sent1_ling = sent1_ling.view(bs, 1, -1)
            inputs_embeds = inputs_embeds + sent1_ling
        gen = self.gen(inputs_embeds=inputs_embeds,
                       attention_mask=inputs_att_mask).last_hidden_state.mean(1)
        # gen = self.gen(batch['sentence1_ling'])
        cache = {}
        if self.gen_type == 'vae':
            mu = self.gen_mu(gen)
            logvar = self.gen_logvar(gen)
            output = vae_sample(mu, logvar)
            cache['linggen_mu'] = mu
            cache['linggen_logvar'] = logvar
        elif self.gen_type == 'det':
            output = self.projection(gen)
        return output, cache
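

# Illustrative sketch (not part of the original module): how the vae_sample/VAE pieces
# above fit together. The Namespace dimensions below are hypothetical placeholders; the
# real values come from the project's argument parser.
def _vae_usage_sketch():
    from argparse import Namespace
    args = Namespace(input_dim=40, hidden_dim=128, latent_dim=16)  # hypothetical dims
    vae = VAE(args)
    x = torch.randn(8, args.input_dim)
    recon, (mu, logvar) = vae(x)
    # Standard VAE objective: reconstruction term plus KL(q(z|x) || N(0, I)).
    recon_loss = F.mse_loss(recon, x)
    kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl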


class LingDisc(nn.Module):
    def __init__(self,
                 model_name,
                 disc_type,
                 disc_ckpt,
                 lng_dim=40,
                 quant_nbins=1,
                 disc_lng_dim=None,
                 lng_ids=None,
                 **kwargs):
        super().__init__()
        if disc_type == 't5':
            self.encoder = T5EncoderModel.from_pretrained(model_name)
            hidden_dim = self.encoder.config.d_model
            self.dropout = nn.Dropout(0.2)
            self.lng_dim = disc_lng_dim if disc_lng_dim else lng_dim
            self.quant = quant_nbins > 1
            self.quant = False  # quantized prediction currently disabled
            if self.quant:
                self.ling_classifier = nn.Linear(hidden_dim, self.lng_dim * quant_nbins)
            else:
                self.ling_classifier = nn.Linear(hidden_dim, self.lng_dim)
            lng_ids = torch.tensor(lng_ids) if lng_ids is not None else None
            # from const import used_indices
            # lng_ids = torch.tensor(used_indices)
            self.register_buffer('lng_ids', lng_ids)
        elif disc_type == 'deberta':
            self.encoder = DebertaReplacedTokenizer.from_pretrained(
                pretrained_model_name_or_path=disc_ckpt,
                tok_model_name=model_name,
                problem_type='regression',
                num_labels=40)
            self.quant = False
        self.disc_type = disc_type

    def forward(self, **batch):
        if 'attention_mask' not in batch:
            if 'input_ids' in batch:
                att_mask = torch.ones_like(batch['input_ids'])
            else:
                att_mask = torch.ones_like(batch['logits'])[:, :, 0]
        else:
            att_mask = batch['attention_mask']
        if 'input_ids' in batch:
            enc_output = self.encoder(input_ids=batch['input_ids'], attention_mask=att_mask)
        elif 'logits' in batch:
            # Straight-through soft inputs: the forward pass uses the hard one-hot of the
            # logits, while gradients flow through the softmax scores.
            logits = batch['logits']
            scores = F.softmax(logits, dim=-1)
            onehot = F.one_hot(logits.argmax(-1), num_classes=logits.shape[2]).float().to(logits.device)
            onehot_ = scores - scores.detach() + onehot
            embed_layer = self.encoder.get_input_embeddings()
            if isinstance(embed_layer, nn.Sequential):
                for i, module in enumerate(embed_layer):
                    if i == 0:
                        embeds = torch.matmul(onehot_, module.weight)
                    else:
                        embeds = module(embeds)
            else:
                embeds = torch.matmul(onehot_, embed_layer.weight)
            enc_output = self.encoder(inputs_embeds=embeds, attention_mask=att_mask)
        if self.disc_type == 't5':
            sent_emb = self.dropout(enc_output.last_hidden_state.mean(1))
            bs = sent_emb.shape[0]
            output = self.ling_classifier(sent_emb)
            if self.quant:
                output = output.reshape(bs, -1, self.lng_dim)
            if self.lng_ids is not None:
                output = torch.index_select(output, 1, self.lng_ids)
        elif self.disc_type == 'deberta':
            output = enc_output.logits
        return output


class SemEmb(T5EncoderModel):
    def __init__(self, config, sep_token_id):
        super().__init__(config)
        self.sep_token_id = sep_token_id
        hidden_dim = self.config.d_model
        self.projection = nn.Sequential(nn.ReLU(),
                                        nn.Dropout(0.2),
                                        nn.Linear(hidden_dim, 1))

    def compare_sem(self, **batch):
        bs = batch['sentence1_attention_mask'].shape[0]
        ones = torch.ones((bs, 1), device=batch['sentence1_attention_mask'].device)
        sep = torch.ones((bs, 1), dtype=torch.long,
                         device=batch['sentence1_attention_mask'].device) * self.sep_token_id
        att_mask = torch.cat([batch['sentence1_attention_mask'], ones,
                              batch['sentence2_attention_mask']], dim=1)
        if 'logits' in batch:
            input_ids = torch.cat([batch['sentence1_input_ids'], sep], dim=1)
            embeds1 = self.shared(input_ids)
            logits = batch['logits']
            scores = F.softmax(logits, dim=-1)
            onehot = F.one_hot(logits.argmax(-1), num_classes=logits.shape[2]).float().to(logits.device)
            onehot_ = scores - scores.detach() + onehot
            embeds2 = onehot_ @ self.shared.weight
            embeds1_2 = torch.cat([embeds1, embeds2], dim=1)
            hidden_units = self(inputs_embeds=embeds1_2,
                                attention_mask=att_mask).last_hidden_state.mean(1)
        elif 'sentence2_input_ids' in batch:
            input_ids = torch.cat([batch['sentence1_input_ids'], sep, batch['sentence2_input_ids']], dim=1)
            hidden_units = self(input_ids=input_ids,
                                attention_mask=att_mask).last_hidden_state.mean(1)
        probs = self.projection(hidden_units)
        return probs
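

# Illustrative sketch (not from the original code): the straight-through trick used in
# LingDisc.forward and SemEmb.compare_sem above. The forward pass sees the hard one-hot
# of the generator's logits, while gradients flow through the soft distribution, so a
# downstream loss can be backpropagated into the generator.
def _straight_through_embed_sketch(logits, embedding_weight):
    # logits: (batch, seq_len, vocab); embedding_weight: (vocab, hidden)
    scores = F.softmax(logits, dim=-1)
    onehot = F.one_hot(logits.argmax(-1), num_classes=logits.shape[-1]).float()
    onehot_st = scores - scores.detach() + onehot  # value == onehot, gradient == d(scores)
    return onehot_st @ embedding_weight            # differentiable "embedding lookup"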


def prepare_inputs_for_generation(
    combine_method,
    ling2_only,
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    cross_attn_head_mask=None,
    use_cache=None,
    encoder_outputs=None,
    sent1_ling=None,
    sent2_ling=None,
    **kwargs
):
    # cut decoder_input_ids if past is used
    if past_key_values is not None:
        input_ids = input_ids[:, -1:]
    input_ids = input_ids.clone()
    decoder_inputs_embeds = self.shared(input_ids)
    if combine_method == 'decoder_add_first':
        sent2_ling = torch.cat([sent2_ling,
                                torch.repeat_interleave(torch.zeros_like(sent2_ling),
                                                        input_ids.shape[1] - 1, dim=1)], dim=1)
    if combine_method == 'decoder_concat':
        if ling2_only:
            decoder_inputs_embeds = torch.cat([sent2_ling, decoder_inputs_embeds], dim=1)
        else:
            decoder_inputs_embeds = torch.cat([sent1_ling, sent2_ling, decoder_inputs_embeds], dim=1)
    elif combine_method == 'decoder_add' or (past_key_values is None and combine_method == 'decoder_add_first'):
        if ling2_only:
            decoder_inputs_embeds = decoder_inputs_embeds + sent2_ling
        else:
            decoder_inputs_embeds = decoder_inputs_embeds + sent1_ling + sent2_ling
    return {
        "decoder_inputs_embeds": decoder_inputs_embeds,
        "past_key_values": past_key_values,
        "encoder_outputs": encoder_outputs,
        "attention_mask": attention_mask,
        "head_mask": head_mask,
        "decoder_head_mask": decoder_head_mask,
        "cross_attn_head_mask": cross_attn_head_mask,
        "use_cache": use_cache,
    }


class LogitsAdd(LogitsProcessor):
    def __init__(self, sent2_ling):
        super().__init__()
        self.sent2_ling = sent2_ling

    def __call__(self, input_ids, scores):
        return scores + self.sent2_ling
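

# Illustrative sketch (hypothetical shapes, not part of the original module): LogitsAdd
# biases every decoding step's scores with a fixed vector, which is how the 'logits_add'
# combine method injects the target linguistic embedding during generation.
def _logits_add_sketch():
    bs, vocab = 2, 32128
    bias = torch.randn(bs, vocab)            # e.g. sent2_ling projected to vocabulary size
    processor = LogitsProcessorList([LogitsAdd(bias)])
    scores = torch.randn(bs, vocab)          # raw next-token scores at one decoding step
    dummy_ids = torch.zeros(bs, 1, dtype=torch.long)
    return processor(dummy_ids, scores)      # == scores + bias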


class EncoderDecoderVAE(T5ForConditionalGeneration):
    def __init__(self, config, args, pad_token_id, sepeos_token_id, vocab_size=32128):
        super().__init__(config)
        self.prepare_inputs_for_generation = types.MethodType(
            partial(prepare_inputs_for_generation, args.combine_method, args.ling2_only), self)
        self.args = args
        self.pad_token_id = pad_token_id
        self.eos_token_id = sepeos_token_id
        hidden_dim = self.config.d_model if 'logits' not in args.combine_method else vocab_size
        if args.combine_method == 'fusion1':
            self.fusion = nn.Sequential(
                nn.Linear(hidden_dim + 2 * args.lng_dim, hidden_dim),
            )
        elif args.combine_method == 'fusion2':
            self.fusion = nn.Sequential(
                nn.Linear(hidden_dim + 2 * args.lng_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
        elif 'concat' in args.combine_method or 'add' in args.combine_method:
            if args.ling_embed_type == 'two-layer':
                self.ling_embed = nn.Sequential(
                    nn.Linear(args.lng_dim, args.lng_dim),
                    nn.ReLU(),
                    nn.Linear(args.lng_dim, hidden_dim),
                )
            else:
                self.ling_embed = nn.Linear(args.lng_dim, hidden_dim)
            self.ling_dropout = nn.Dropout(args.ling_dropout)
            if args.ling_vae:
                self.ling_mu = nn.Linear(hidden_dim, hidden_dim)
                self.ling_logvar = nn.Linear(hidden_dim, hidden_dim)
                nn.init.xavier_uniform_(self.ling_embed.weight)
                nn.init.xavier_uniform_(self.ling_mu.weight)
                nn.init.xavier_uniform_(self.ling_logvar.weight)
        generate_with_grad = unwrap(self.generate)
        self.generate_with_grad = MethodType(generate_with_grad, self)

    def get_fusion_layer(self):
        if 'fusion' in self.args.combine_method:
            return self.fusion
        elif 'concat' in self.args.combine_method or 'add' in self.args.combine_method:
            return self.ling_embed
        else:
            return None

    def sample(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def encode(self, batch):
        if 'inputs_embeds' in batch:
            inputs_embeds = batch['inputs_embeds']
        else:
            inputs_embeds = self.shared(batch['sentence1_input_ids'])
        inputs_att_mask = batch['sentence1_attention_mask']
        bs = inputs_embeds.shape[0]
        cache = {}
        if self.args.combine_method in ('input_concat', 'input_add'):
            if 'sent1_ling_embed' in batch:
                sent1_ling = batch['sent1_ling_embed']
            else:
                sent1_ling = self.ling_embed(self.ling_dropout(batch['sentence1_ling']))
            if 'sent2_ling_embed' in batch:
                sent2_ling = batch['sent2_ling_embed']
            else:
                sent2_ling = self.ling_embed(self.ling_dropout(batch['sentence2_ling']))
            if self.args.ling_vae:
                sent1_ling = F.leaky_relu(sent1_ling)
                sent1_mu, sent1_logvar = self.ling_mu(sent1_ling), self.ling_logvar(sent1_ling)
                sent1_ling = self.sample(sent1_mu, sent1_logvar)
                sent2_ling = F.leaky_relu(sent2_ling)
                sent2_mu, sent2_logvar = self.ling_mu(sent2_ling), self.ling_logvar(sent2_ling)
                sent2_ling = self.sample(sent2_mu, sent2_logvar)
                cache.update({'sent1_mu': sent1_mu, 'sent1_logvar': sent1_logvar,
                              'sent2_mu': sent2_mu, 'sent2_logvar': sent2_logvar,
                              'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
            else:
                cache.update({'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
            sent1_ling = sent1_ling.view(bs, 1, -1)
            sent2_ling = sent2_ling.view(bs, 1, -1)
            if self.args.combine_method == 'input_concat':
                if self.args.ling2_only:
                    inputs_embeds = torch.cat([inputs_embeds, sent2_ling], dim=1)
                    inputs_att_mask = torch.cat([inputs_att_mask,
                                                 torch.ones((bs, 1)).to(inputs_embeds.device)], dim=1)
                else:
                    inputs_embeds = torch.cat([inputs_embeds, sent1_ling, sent2_ling], dim=1)
                    inputs_att_mask = torch.cat([inputs_att_mask,
                                                 torch.ones((bs, 2)).to(inputs_embeds.device)], dim=1)
            elif self.args.combine_method == 'input_add':
                if self.args.ling2_only:
                    inputs_embeds = inputs_embeds + sent2_ling
                else:
                    inputs_embeds = inputs_embeds + sent1_ling + sent2_ling
        return self.encoder(inputs_embeds=inputs_embeds, attention_mask=inputs_att_mask), inputs_att_mask, cache
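
    # decode() injects the target (sentence 2) linguistic embedding according to
    # args.combine_method:
    #   - decoder_concat: ling embeddings are prepended to the decoder input embeddings,
    #   - decoder_add / decoder_add_first: ling embeddings are added to the decoder input
    #     embeddings (only at the first position for decoder_add_first),
    #   - logits_add: a vocabulary-sized ling embedding is added to the output logits,
    #   - embed_concat: ling embeddings are appended to the encoder output states,
    #   - fusion1 / fusion2: raw ling features are concatenated to every encoder state and
    #     projected back to the model dimension,
    #   - input_concat / input_add are handled on the encoder side in encode() above.
    # With generate=True it decodes via generate_with_grad and returns (sequences, cache);
    # otherwise it teacher-forces sentence 2 and returns the usual seq2seq output with loss.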
    def decode(self, batch, enc_output, inputs_att_mask, generate):
        bs = inputs_att_mask.shape[0]
        cache = {}
        if self.args.combine_method in ('embed_concat', 'decoder_concat', 'decoder_add',
                                        'logits_add', 'decoder_add_first'):
            if 'sent1_ling_embed' in batch:
                sent1_ling = batch['sent1_ling_embed']
            elif 'sentence1_ling' in batch:
                sent1_ling = self.ling_embed(self.ling_dropout(batch['sentence1_ling']))
            else:
                sent1_ling = None
            if 'sent2_ling_embed' in batch:
                sent2_ling = batch['sent2_ling_embed']
            else:
                sent2_ling = self.ling_embed(self.ling_dropout(batch['sentence2_ling']))
            if self.args.ling_vae:
                sent1_ling = F.leaky_relu(sent1_ling)
                sent1_mu, sent1_logvar = self.ling_mu(sent1_ling), self.ling_logvar(sent1_ling)
                sent1_ling = self.sample(sent1_mu, sent1_logvar)
                sent2_ling = F.leaky_relu(sent2_ling)
                sent2_mu, sent2_logvar = self.ling_mu(sent2_ling), self.ling_logvar(sent2_ling)
                sent2_ling = self.sample(sent2_mu, sent2_logvar)
                cache.update({'sent1_mu': sent1_mu, 'sent1_logvar': sent1_logvar,
                              'sent2_mu': sent2_mu, 'sent2_logvar': sent2_logvar,
                              'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
            else:
                cache.update({'sent2_ling': sent2_ling})
                if sent1_ling is not None:
                    cache.update({'sent1_ling': sent1_ling})
            if sent1_ling is not None:
                sent1_ling = sent1_ling.view(bs, 1, -1)
            sent2_ling = sent2_ling.view(bs, 1, -1)
            if self.args.combine_method == 'decoder_add_first' and not generate:
                sent2_ling = torch.cat([sent2_ling,
                                        torch.repeat_interleave(torch.zeros_like(sent2_ling),
                                                                batch['sentence2_input_ids'].shape[1] - 1,
                                                                dim=1)], dim=1)
        else:
            sent1_ling, sent2_ling = None, None
        if self.args.combine_method == 'embed_concat':
            enc_output.last_hidden_state = torch.cat([enc_output.last_hidden_state,
                                                      sent1_ling, sent2_ling], dim=1)
            inputs_att_mask = torch.cat([inputs_att_mask,
                                         torch.ones((bs, 2)).to(inputs_att_mask.device)], dim=1)
        elif 'fusion' in self.args.combine_method:
            sent1_ling = batch['sentence1_ling'].unsqueeze(1)\
                .expand(-1, enc_output.last_hidden_state.shape[1], -1)
            sent2_ling = batch['sentence2_ling'].unsqueeze(1)\
                .expand(-1, enc_output.last_hidden_state.shape[1], -1)
            if self.args.ling2_only:
                combined_embedding = torch.cat([enc_output.last_hidden_state, sent2_ling], dim=2)
            else:
                combined_embedding = torch.cat([enc_output.last_hidden_state, sent1_ling, sent2_ling], dim=2)
            enc_output.last_hidden_state = self.fusion(combined_embedding)
        if generate:
            if self.args.combine_method == 'logits_add':
                logits_processor = LogitsProcessorList([LogitsAdd(sent2_ling.view(bs, -1))])
            else:
                logits_processor = LogitsProcessorList()
            dec_output = self.generate_with_grad(
                attention_mask=inputs_att_mask,
                encoder_outputs=enc_output,
                sent1_ling=sent1_ling,
                sent2_ling=sent2_ling,
                return_dict_in_generate=True,
                output_scores=True,
                logits_processor=logits_processor,
                # renormalize_logits=True,
                # do_sample=True,
                # top_p=0.8,
                eos_token_id=self.eos_token_id,
                # min_new_tokens=3,
                # repetition_penalty=1.2,
                max_length=self.args.max_length,
            )
            scores = torch.stack(dec_output.scores, 1)
            cache.update({'scores': scores})
            return dec_output.sequences, cache
        decoder_input_ids = self._shift_right(batch['sentence2_input_ids'])
        decoder_inputs_embeds = self.shared(decoder_input_ids)
        decoder_att_mask = batch['sentence2_attention_mask']
        labels = batch['sentence2_input_ids'].clone()
        labels[labels == self.pad_token_id] = -100
        if self.args.combine_method == 'decoder_concat':
            if self.args.ling2_only:
                decoder_inputs_embeds = torch.cat([sent2_ling, decoder_inputs_embeds], dim=1)
                decoder_att_mask = torch.cat([torch.ones((bs, 1)).to(decoder_inputs_embeds.device),
                                              decoder_att_mask], dim=1)
                labels = torch.cat([torch.ones((bs, 1), dtype=torch.int64).to(decoder_inputs_embeds.device)
                                    * self.pad_token_id, labels], dim=1)
            else:
                decoder_inputs_embeds = torch.cat([sent1_ling, sent2_ling, decoder_inputs_embeds], dim=1)
                decoder_att_mask = torch.cat([torch.ones((bs, 2)).to(decoder_inputs_embeds.device),
                                              decoder_att_mask], dim=1)
                labels = torch.cat([torch.ones((bs, 2), dtype=torch.int64).to(decoder_inputs_embeds.device)
                                    * self.pad_token_id, labels], dim=1)
        elif self.args.combine_method == 'decoder_add' or self.args.combine_method == 'decoder_add_first':
            if self.args.ling2_only:
                decoder_inputs_embeds = decoder_inputs_embeds + self.args.combine_weight * sent2_ling
            else:
                decoder_inputs_embeds = decoder_inputs_embeds + sent1_ling + sent2_ling
        dec_output = self(
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_attention_mask=decoder_att_mask,
            encoder_outputs=enc_output,
            attention_mask=inputs_att_mask,
            labels=labels,
        )
        if self.args.combine_method == 'logits_add':
            dec_output.logits = dec_output.logits + self.args.combine_weight * sent2_ling
            vocab_size = dec_output.logits.size(-1)
            dec_output.loss = F.cross_entropy(dec_output.logits.view(-1, vocab_size), labels.view(-1))
        return dec_output, cache

    def convert(self, batch, generate=False):
        enc_output, enc_att_mask, cache = self.encode(batch)
        dec_output, cache2 = self.decode(batch, enc_output, enc_att_mask, generate)
        cache.update(cache2)
        return dec_output, enc_output, cache

    def infer_with_cache(self, batch):
        dec_output, _, cache = self.convert(batch, generate=True)
        return dec_output, cache

    def infer(self, batch):
        dec_output, _ = self.infer_with_cache(batch)
        return dec_output
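
    # infer_with_feedback_BP refines a generation with gradient feedback from the
    # linguistic discriminator: depending on args.feedback_param it treats the target
    # ling embedding ('l'), the source input embeddings ('s'), or the output logits
    # ('logits') as a free parameter, backpropagates the discriminator's loss w.r.t. the
    # target linguistic profile through the differentiable generation, and line-searches
    # a step size that lowers that loss while the semantic-similarity model (sem_emb)
    # still accepts the prediction.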
    def infer_with_feedback_BP(self, ling_disc, sem_emb, batch, tokenizer):
        from torch.autograd import grad
        interpolations = []

        def line_search():
            best_val = None
            best_loss = None
            eta = 1e3
            sem_prob = 1
            patience = 4
            while patience > 0:
                param_ = param - eta * grads
                with torch.no_grad():
                    new_loss, pred = get_loss(param_)
                max_len = pred.shape[1]
                lens = torch.where(pred == self.eos_token_id, 1, 0).argmax(-1) + 1
                batch.update({
                    'sentence2_input_ids': pred,
                    'sentence2_attention_mask': sequence_mask(lens, max_len=max_len),
                })
                sem_prob = torch.sigmoid(sem_emb.compare_sem(**batch)).item()
                # if sem_prob <= 0.1:
                #     patience -= 1
                if new_loss < loss and sem_prob >= 0.90 and lens.item() > 1:
                    return param_
                eta *= 2.25
                patience -= 1
            return False

        def get_loss(param):
            if self.args.feedback_param == 'l':
                batch.update({'sent2_ling_embed': param})
            elif self.args.feedback_param == 's':
                batch.update({'inputs_embeds': param})
            if self.args.feedback_param == 'logits':
                logits = param
                pred = param.argmax(-1)
            else:
                pred, cache = self.infer_with_cache(batch)
                logits = cache['scores']
            out = ling_disc(logits=logits)
            probs = F.softmax(out, 1)
            if ling_disc.quant:
                loss = F.cross_entropy(out, batch['sentence2_discr'])
            else:
                loss = F.mse_loss(out, batch['sentence2_ling'])
            return loss, pred

        if self.args.feedback_param == 'l':
            ling2_embed = self.ling_embed(batch['sentence2_ling'])
            param = torch.nn.Parameter(ling2_embed, requires_grad=True)
        elif self.args.feedback_param == 's':
            inputs_embeds = self.shared(batch['sentence1_input_ids'])
            param = torch.nn.Parameter(inputs_embeds, requires_grad=True)
        elif self.args.feedback_param == 'logits':
            logits = self.infer_with_cache(batch)[1]['scores']
            param = torch.nn.Parameter(logits, requires_grad=True)
        target_np = batch['sentence2_ling'][0].cpu().numpy()
        while True:
            loss, pred = get_loss(param)
            pred_text = tokenizer.batch_decode(pred.cpu().numpy(), skip_special_tokens=True)[0]
            interpolations.append(pred_text)
            if loss < 1:
                break
            self.zero_grad()
            grads = grad(loss, param)[0]
            param = line_search()
            if param is False:
                break
        return pred, [pred_text, interpolations]


def set_grad(module, state):
    if module is not None:
        for p in module.parameters():
            p.requires_grad = state


def set_grad_except(model, name, state):
    for n, p in model.named_parameters():
        if name not in n:
            p.requires_grad = state


class SemEmbPipeline():
    def __init__(self, ckpt="/data/mohamed/checkpoints/ling_conversion_sem_emb_best.pt"):
        self.tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
        self.model = SemEmb.from_pretrained('google/flan-t5-base',
                                            self.tokenizer.get_vocab()['</s>'])
        state = torch.load(ckpt)
        self.model.load_state_dict(state['model'], strict=False)
        self.model.eval()
        self.model.cuda()

    def __call__(self, sentence1, sentence2):
        sentence1 = self.tokenizer(sentence1, return_attention_mask=True, return_tensors='pt')
        sentence2 = self.tokenizer(sentence2, return_attention_mask=True, return_tensors='pt')
        sem_logit = self.model.compare_sem(
            sentence1_input_ids=sentence1.input_ids.cuda(),
            sentence1_attention_mask=sentence1.attention_mask.cuda(),
            sentence2_input_ids=sentence2.input_ids.cuda(),
            sentence2_attention_mask=sentence2.attention_mask.cuda(),
        )
        sem_prob = torch.sigmoid(sem_logit).item()
        return sem_prob


class LingDiscPipeline():
    def __init__(self,
                 model_name="google/flan-t5-base",
                 disc_type='deberta',
                 disc_ckpt='/data/mohamed/checkpoints/ling_disc/deberta-v3-small_flan-t5-base_40',
                 # disc_type='t5',
                 # disc_ckpt='/data/mohamed/checkpoints/ling_conversion_ling_disc.pt',
                 ):
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = LingDisc(model_name, disc_type, disc_ckpt)
        self.model.eval()
        self.model.cuda()

    def __call__(self, sentence):
        inputs = self.tokenizer(sentence, return_tensors='pt')
        with torch.no_grad():
            ling_pred = self.model(input_ids=inputs.input_ids.cuda())
        return ling_pred
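

# Illustrative usage sketch (not part of the original module; requires the checkpoints
# referenced above and a CUDA device): the two pipelines wrap a trained semantic-similarity
# scorer and linguistic-feature predictor for standalone scoring.
def _pipeline_usage_sketch():
    sem = SemEmbPipeline()                  # default checkpoint path from the constructor
    disc = LingDiscPipeline()
    prob = sem("The cat sat on the mat.", "A cat was sitting on the mat.")
    ling = disc("The cat sat on the mat.")  # predicted linguistic feature vector
    return prob, ling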


def get_model(args, tokenizer, device):
    if args.pretrain_disc or args.disc_loss or args.disc_ckpt:
        ling_disc = LingDisc(args.model_name, args.disc_type, args.disc_model_path).to(device)
    else:
        ling_disc = None
    if args.linggen_type != 'none':
        ling_gen = LingGenerator(args).to(device)
    if not args.pretrain_disc:
        model = EncoderDecoderVAE.from_pretrained(args.model_path, args, tokenizer.pad_token_id,
                                                  tokenizer.eos_token_id).to(device)
    else:
        model = ling_disc
    if args.sem_loss or args.sem_ckpt:
        if args.sem_loss_type == 'shared':
            sem_emb = model.encoder
        elif args.sem_loss_type == 'dedicated':
            sem_emb = SemEmb.from_pretrained(args.sem_model_path, tokenizer.eos_token_id).to(device)
        else:
            raise NotImplementedError('Semantic loss type')
    else:
        sem_emb = None
    return model, ling_disc, sem_emb
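

# Illustrative end-to-end sketch (not part of the original module): wiring get_model()
# into a single conversion call. The args Namespace is assumed to come from the project's
# argument parser with a decoder-side combine method (e.g. 'decoder_add') and ling2_only
# set; the target linguistic vector below is a random placeholder.
def _conversion_sketch(args, device='cuda'):
    tokenizer = T5Tokenizer.from_pretrained(args.model_name)
    model, ling_disc, sem_emb = get_model(args, tokenizer, device)
    enc = tokenizer("An example source sentence.", return_tensors='pt').to(device)
    batch = {
        'sentence1_input_ids': enc.input_ids,
        'sentence1_attention_mask': enc.attention_mask,
        'sentence2_ling': torch.randn(1, args.lng_dim, device=device),  # placeholder target
    }
    pred_ids = model.infer(batch)
    return tokenizer.batch_decode(pred_ids, skip_special_tokens=True)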