import os
from typing import List, Union
import numpy as np
import math
import time
import heapq
import torch
from torch import Tensor, nn
from torch.distributions.distribution import Distribution
from transformers import AutoModelForSeq2SeqLM, T5ForConditionalGeneration, T5Tokenizer, AutoTokenizer, GPT2LMHeadModel, GPT2Tokenizer
import random
from typing import Optional
from .tools.token_emb import NewTokenEmb


class MLM(nn.Module):
    """Language-model wrapper whose tokenizer is extended with motion tokens.

    Wraps either a T5 encoder-decoder (``model_type="t5"``) or a GPT-2
    decoder-only model (``model_type="gpt2"``). ``forward`` runs a training
    step; ``generate_direct`` / ``generate_conditional`` run generation and
    convert between motion-token tensors and strings.

    NOTE(review): several string literals in this class are empty as written
    (the tokens passed to ``add_tokens``, the wrappers in
    ``motion_token_to_string`` / ``get_middle_str`` and the placeholder names
    in ``placeholder_fulfill``). They look like angle-bracket tags that were
    lost somewhere upstream; they are kept byte-for-byte here — TODO confirm
    against the intended token / placeholder names.
    """

    def __init__(
        self,
        model_path: str,
        model_type: str = "t5",
        stage: str = "lm_pretrain",
        new_token_type: str = "insert",
        motion_codebook_size: int = 512,
        framerate: float = 20.0,
        down_t: int = 4,
        predict_ratio: float = 0.2,
        inbetween_ratio: float = 0.25,
        max_length: int = 256,
        lora: bool = False,
        quota_ratio: float = 0.5,
        noise_density: float = 0.15,
        mean_noise_span_length: int = 3,
        **kwargs,
    ) -> None:
        """Load the tokenizer/model and register the extra motion tokens.

        Args:
            model_path: Name or path passed to ``from_pretrained``.
            model_type: ``"t5"`` (encoder-decoder) or ``"gpt2"`` (decoder-only).
            stage: Stored on ``self.stage``.
            new_token_type: ``"insert"`` only resizes the embeddings;
                ``"mlp"`` also wraps ``language_model.shared`` in
                ``NewTokenEmb``.
            motion_codebook_size: Codebook size; ``motion_codebook_size + 3``
                tokens are added to the tokenizer.
            framerate: Frames per second, used to turn lengths into seconds.
            down_t: Frames-per-token factor used in ``placeholder_fulfill``.
            predict_ratio: Fraction of tokens used as the observed head in
                ``placeholder_fulfill``.
            inbetween_ratio: Fraction of tokens kept at each end of the masked
                (in-between) motion in ``placeholder_fulfill``.
            max_length: Tokenizer padding / truncation length.
            lora: Wrap the model with a PEFT LoRA adapter.
            quota_ratio: Probability of wrapping the caption in double quotes.
            noise_density: Fraction of tokens masked for span corruption.
            mean_noise_span_length: Mean span length for span corruption.
            **kwargs: Ignored.

        Raises:
            ValueError: If ``model_type`` is neither ``"t5"`` nor ``"gpt2"``.
        """
        super().__init__()

        # Parameters
        self.m_codebook_size = motion_codebook_size
        self.max_length = max_length
        self.framerate = framerate
        self.down_t = down_t
        self.predict_ratio = predict_ratio
        self.inbetween_ratio = inbetween_ratio
        self.noise_density = noise_density
        self.mean_noise_span_length = mean_noise_span_length
        self.quota_ratio = quota_ratio
        self.stage = stage

        # Instantiate language model
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, legacy=True)
        if model_type == "t5":
            self.language_model = T5ForConditionalGeneration.from_pretrained(
                model_path)
            self.lm_type = 'encdec'
        elif model_type == "gpt2":
            self.language_model = GPT2LMHeadModel.from_pretrained(model_path)
            self.lm_type = 'dec'
        else:
            raise ValueError(
                f"model_type must be either 't5' or 'gpt2', got {model_type!r}")

        if self.lm_type == 'dec':
            # The decoder-only path pads with the EOS token.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Add motion tokens
        self.tokenizer.add_tokens(
            [f'' for i in range(self.m_codebook_size + 3)])

        if new_token_type == "insert":
            self.language_model.resize_token_embeddings(len(self.tokenizer))
        elif new_token_type == "mlp":
            # Wrap the original shared embedding before resizing, then
            # install the wrapper.
            shared = NewTokenEmb(self.language_model.shared,
                                 self.m_codebook_size + 3)
            self.language_model.resize_token_embeddings(len(self.tokenizer))
            self.language_model.shared = shared

        # Lora
        if lora:
            from peft import LoraConfig, get_peft_model

            # NOTE(review): task_type is CAUSAL_LM for the T5 path as well;
            # confirm whether SEQ_2_SEQ_LM was intended there.
            peft_config = LoraConfig(
                bias="none",
                task_type="CAUSAL_LM",
                r=8,
                lora_alpha=16,
                lora_dropout=0.05)
            self.language_model = get_peft_model(self.language_model,
                                                 peft_config)

    def forward(self, texts: List[str], motion_tokens: Tensor,
                lengths: List[int], tasks: List[dict]):
        """Run one training step with the forward pass for ``self.lm_type``.

        Returns:
            Whatever ``forward_encdec`` / ``forward_dec`` returns.

        Raises:
            NotImplementedError: If ``self.lm_type`` is not recognised.
        """
        if self.lm_type == 'encdec':
            return self.forward_encdec(texts, motion_tokens, lengths, tasks)
        elif self.lm_type == 'dec':
            return self.forward_dec(texts, motion_tokens, lengths, tasks)
        else:
            raise NotImplementedError("Only conditional_multitask supported")

    def forward_encdec(
        self,
        texts: List[str],
        motion_tokens: Tensor,
        lengths: List[int],
        tasks: List[dict],
    ):
        """Training step for the encoder-decoder (T5) model.

        Args:
            texts: One caption per sample.
            motion_tokens: Motion-token ids per sample; row ``i`` is truncated
                to ``lengths[i]`` when converted to a string.
            lengths: Number of valid motion tokens per sample.
            tasks: Per-sample templates indexed as ``tasks[i]['input']`` /
                ``tasks[i]['output']`` by ``template_fulfill``.

        Returns:
            The output of ``self.language_model`` called with ``labels``.
        """
        device = motion_tokens.device

        # Tensor to string
        motion_strings = self.motion_token_to_string(motion_tokens, lengths)

        # Supervised or unsupervised. The 'text' / 'motion' (span-corruption)
        # conditions are currently disabled, so this always yields
        # 'supervised'; the random.choice call is kept so the RNG stream is
        # unchanged.
        condition = random.choice(['supervised', 'supervised', 'supervised'])

        if condition == 'text':
            inputs = texts
            outputs = texts
        elif condition == 'motion':
            inputs = motion_strings
            outputs = motion_strings
        else:
            inputs, outputs = self.template_fulfill(tasks, lengths,
                                                    motion_strings, texts)

        # Tokenize
        source_encoding = self.tokenizer(inputs,
                                         padding='max_length',
                                         max_length=self.max_length,
                                         truncation=True,
                                         return_attention_mask=True,
                                         add_special_tokens=True,
                                         return_tensors="pt")
        source_attention_mask = source_encoding.attention_mask.to(device)
        source_input_ids = source_encoding.input_ids.to(device)
        labels_attention_mask = None

        if condition in ['text', 'motion']:
            # T5-style span corruption: masked spans in the source are
            # replaced by sentinels, and the target holds the removed spans.
            batch_size, expanded_input_length = source_input_ids.shape
            mask_indices = np.asarray([
                self.random_spans_noise_mask(expanded_input_length)
                for _ in range(batch_size)
            ])
            target_mask = ~mask_indices

            input_ids_sentinel = self.create_sentinel_ids(
                mask_indices.astype(np.int8))
            target_sentinel = self.create_sentinel_ids(
                target_mask.astype(np.int8))

            labels_input_ids = self.filter_input_ids(source_input_ids,
                                                     target_sentinel)
            source_input_ids = self.filter_input_ids(source_input_ids,
                                                     input_ids_sentinel)
        else:
            target_inputs = self.tokenizer(outputs,
                                           padding='max_length',
                                           max_length=self.max_length,
                                           truncation=True,
                                           return_attention_mask=True,
                                           add_special_tokens=True,
                                           return_tensors="pt")
            labels_input_ids = target_inputs.input_ids.to(device)
            labels_attention_mask = target_inputs.attention_mask.to(device)

        # Padding positions are excluded from the loss.
        labels_input_ids[labels_input_ids ==
                         self.tokenizer.pad_token_id] = -100

        # The filtered span-corruption ids no longer line up with the
        # tokenizer masks, so masks are only passed for 'supervised'.
        supervised = condition == 'supervised'
        outputs = self.language_model(
            input_ids=source_input_ids,
            attention_mask=source_attention_mask if supervised else None,
            labels=labels_input_ids,
            decoder_attention_mask=labels_attention_mask
            if supervised else None,
        )

        return outputs

    def forward_dec(
        self,
        texts: List[str],
        motion_tokens: Tensor,
        lengths: List[int],
        tasks: List[dict],
    ):
        """Training step for the decoder-only (GPT-2) model.

        The prompt and answer are concatenated as
        ``prompt + ' \\n ' + answer + eos`` and trained as one sequence.

        Args:
            texts: One caption per sample.
            motion_tokens: Motion-token ids per sample.
            lengths: Number of valid motion tokens per sample.
            tasks: Per-sample templates consumed by ``template_fulfill``.

        Returns:
            The output of ``self.language_model`` called with ``labels``.
        """
        device = motion_tokens.device
        self.tokenizer.padding_side = "right"

        # Tensor to string
        motion_strings = self.motion_token_to_string(motion_tokens, lengths)

        # Supervised or unsupervised
        condition = random.choice(
            ['text', 'motion', 'supervised', 'supervised', 'supervised'])

        if condition == 'text':
            sequences = texts
        elif condition == 'motion':
            sequences = motion_strings
        else:
            inputs, outputs = self.template_fulfill(tasks, lengths,
                                                    motion_strings, texts)
            sequences = [
                prompt + ' \n ' + answer + self.tokenizer.eos_token
                for prompt, answer in zip(inputs, outputs)
            ]

        # Tokenize
        encoding = self.tokenizer(sequences,
                                  padding='max_length',
                                  max_length=self.max_length,
                                  truncation=True,
                                  return_attention_mask=True,
                                  return_tensors="pt")
        input_ids = encoding.input_ids.to(device)
        attention_mask = encoding.attention_mask.to(device)

        # pad_token == eos_token on this path, so padding has to be excluded
        # from the loss explicitly; the EOS appended to the text above is a
        # real token with attention 1 and stays supervised.
        labels = input_ids.masked_fill(attention_mask == 0, -100)

        outputs = self.language_model(input_ids=input_ids,
                                      attention_mask=attention_mask,
                                      labels=labels)

        return outputs

    def generate_direct(self,
                        texts: List[str],
                        max_length: int = 256,
                        num_beams: int = 1,
                        do_sample: bool = True,
                        bad_words_ids: Optional[List[List[int]]] = None):
        """Generate from already-formatted prompts and parse the results.

        Args:
            texts: Prompt strings, one per sample.
            max_length: ``max_length`` for the encoder-decoder path;
                ``max_new_tokens`` for the decoder-only path.
            num_beams: Beam count (encoder-decoder path only).
            do_sample: Whether to sample.
            bad_words_ids: Forwarded to ``generate`` (encoder-decoder only).

        Returns:
            ``(motion_tokens, cleaned_text)`` from ``motion_string_to_token``.
        """
        # Device
        self.device = self.language_model.device

        if self.lm_type == 'dec':
            texts = [text + " \n " for text in texts]
            # Decoder-only generation continues from the end of each row, so
            # prompts must be left-padded. forward_dec switches the tokenizer
            # to right padding, so set the side before tokenizing.
            self.tokenizer.padding_side = 'left'

        # Tokenize
        source_encoding = self.tokenizer(texts,
                                         padding='max_length',
                                         max_length=self.max_length,
                                         truncation=True,
                                         return_attention_mask=True,
                                         add_special_tokens=True,
                                         return_tensors="pt")

        source_input_ids = source_encoding.input_ids.to(self.device)
        source_attention_mask = source_encoding.attention_mask.to(self.device)

        if self.lm_type == 'encdec':
            outputs = self.language_model.generate(
                source_input_ids,
                attention_mask=source_attention_mask,
                max_length=max_length,
                num_beams=num_beams,
                do_sample=do_sample,
                bad_words_ids=bad_words_ids,
            )
        elif self.lm_type == 'dec':
            outputs = self.language_model.generate(
                input_ids=source_input_ids,
                attention_mask=source_attention_mask,
                pad_token_id=self.tokenizer.pad_token_id,
                do_sample=do_sample,
                max_new_tokens=max_length)

        outputs_string = self.tokenizer.batch_decode(outputs,
                                                     skip_special_tokens=True)

        outputs_tokens, cleaned_text = self.motion_string_to_token(
            outputs_string)

        return outputs_tokens, cleaned_text

    def generate_conditional(self,
                             texts: Optional[List[str]] = None,
                             motion_tokens: Optional[Tensor] = None,
                             lengths: Optional[List[int]] = None,
                             task: str = "t2m",
                             with_len: bool = False,
                             stage: str = 'train',
                             tasks: Optional[List[dict]] = None):
        """Build task prompts and generate.

        Args:
            texts: Captions (required for ``"t2m"``).
            motion_tokens: Motion-token ids (required for ``"pred"``,
                ``"inbetween"`` and ``"m2t"``).
            lengths: Token counts per sample (required with motion_tokens).
            task: ``"t2m"``, ``"pred"``, ``"inbetween"`` or ``"m2t"``.
            with_len: Use the length-aware prompt templates.
            stage: Forwarded to ``template_fulfill``.
            tasks: Optional custom templates for ``"t2m"`` without length.

        Returns:
            A list of motion-token tensors for the motion tasks, a list of
            strings for ``"m2t"``, or ``None`` for any other task name.

        Raises:
            NotImplementedError: For ``"m2m"``, which has no prompt template.
        """
        self.device = self.language_model.device

        if task in ["t2m", "m2m", "pred", "inbetween"]:

            if task == "t2m":
                assert texts is not None
                motion_strings = [''] * len(texts)
                if not with_len:
                    if tasks is None:
                        tasks = [{
                            'input': ['Generate motion: '],
                            'output': ['']
                        }] * len(texts)

                    lengths = [0] * len(texts)
                else:
                    tasks = [{
                        'input': ['Generate motion with frames: '],
                        'output': ['']
                    }] * len(texts)

            elif task == "pred":
                assert motion_tokens is not None and lengths is not None
                texts = [''] * len(lengths)
                tasks = [{
                    'input': ['Predict motion: '],
                    'output': ['']
                }] * len(lengths)

                # Keep only the first length // 5 '>'-separated pieces of
                # each motion string as the observed prefix.
                full_strings = self.motion_token_to_string(
                    motion_tokens, lengths)
                motion_strings = [
                    '>'.join(full_strings[i].split('>')[:length // 5]) + '>'
                    for i, length in enumerate(lengths)
                ]

            elif task == "inbetween":
                assert motion_tokens is not None and lengths is not None
                texts = [''] * len(lengths)
                tasks = [{
                    'input': ["Complete the masked motion: "],
                    'output': ['']
                }] * len(lengths)
                motion_strings = self.motion_token_to_string(
                    motion_tokens, lengths)

            else:
                # "m2m" is accepted above but has no prompt template; the
                # previous code crashed with UnboundLocalError here.
                raise NotImplementedError(
                    f"generate_conditional does not implement task {task!r}")

            inputs, outputs = self.template_fulfill(tasks, lengths,
                                                    motion_strings, texts,
                                                    stage)

            outputs_tokens, cleaned_text = self.generate_direct(inputs,
                                                                max_length=128,
                                                                num_beams=1,
                                                                do_sample=True)

            return outputs_tokens

        elif task == "m2t":
            assert motion_tokens is not None and lengths is not None

            motion_strings = self.motion_token_to_string(
                motion_tokens, lengths)

            if not with_len:
                tasks = [{
                    'input': ['Generate text: '],
                    'output': ['']
                }] * len(lengths)
            else:
                tasks = [{
                    'input': ['Generate text with frames: '],
                    'output': ['']
                }] * len(lengths)

            texts = [''] * len(lengths)

            inputs, outputs = self.template_fulfill(tasks, lengths,
                                                    motion_strings, texts)
            outputs_tokens, cleaned_text = self.generate_direct(
                inputs,
                max_length=40,
                num_beams=1,
                do_sample=False,
            )
            return cleaned_text

    def motion_token_to_string(self, motion_token: Tensor, lengths: List[int]):
        """Convert each row of motion-token ids (truncated to its length) to a string.

        Args:
            motion_token: Indexable of 1-D tensors (e.g. a 2-D tensor).
            lengths: ``lengths[i]`` tokens of row ``i`` are used.

        Returns:
            One string per row.
        """
        motion_string = []
        for row_idx, row in enumerate(motion_token):
            # tolist() copies to host memory regardless of device.
            motion_list = row.tolist()[:lengths[row_idx]]
            motion_string.append(
                (f'' + ''.join([f'' for i in motion_list]) + f''))
        return motion_string

    def motion_token_list_to_string(self, motion_token: Tensor):
        """Convert each sequence of motion-token ids (untruncated) to a string."""
        motion_string = []
        for row in motion_token:
            motion_list = row.tolist()
            motion_string.append(
                (f'' + ''.join([f'' for i in motion_list]) + f''))
        return motion_string

    def motion_string_to_token(self, motion_string: List[str]):
        """Parse motion-token ids out of generated strings.

        Args:
            motion_string: Decoded model outputs.

        Returns:
            ``(motion_tokens, output_string)``: a list of int64 tensors on the
            language model's device (``[0]`` when no id was parsed), and the
            inputs with the extracted motion span removed.
        """
        device = self.language_model.device
        motion_tokens = []
        output_string = []
        for text in motion_string:
            string = self.get_middle_str(text, f'', f'')
            string_list = string.split('><')
            # The first and last pieces are dropped; each remaining piece's
            # id is the integer after its last '_'.
            token_list = [
                int(piece.split('_')[-1].replace('>', ''))
                for piece in string_list[1:-1]
            ]
            if len(token_list) == 0:
                token_list = [0]
            motion_tokens.append(
                torch.tensor(token_list, dtype=torch.long, device=device))
            output_string.append(text.replace(string, ''))

        return motion_tokens, output_string

    def placeholder_fulfill(self, prompt: str, length: int, motion_string: str,
                            text: str):
        """Substitute caption, motion, length and derived motion spans into a prompt template.

        Args:
            prompt: Template string.
            length: Motion length in frames.
            motion_string: Full motion string for the sample.
            text: Caption; wrapped in double quotes with probability
                ``self.quota_ratio``.

        Returns:
            The filled-in prompt.
        """
        seconds = math.floor(length / self.framerate)
        motion_splited = motion_string.split('>')
        token_length = length / self.down_t
        predict_head = int(token_length * self.predict_ratio + 1)
        masked_head = int(token_length * self.inbetween_ratio + 1)
        masked_tail = int(token_length * (1 - self.inbetween_ratio) + 1)

        motion_predict_head = '>'.join(
            motion_splited[:predict_head]
        ) + f'>'
        motion_predict_last = f'' + '>'.join(
            motion_splited[predict_head:])

        # Pieces in [masked_head, masked_tail) are replaced by repeated
        # filler literals.
        motion_masked = '>'.join(
            motion_splited[:masked_head]
        ) + '>' + f'' * (
            masked_tail - masked_head) + '>'.join(motion_splited[masked_tail:])

        if random.random() < self.quota_ratio:
            text = f'\"{text}\"'

        prompt = prompt.replace('', text).replace(
            '', motion_string).replace('', f'{length}').replace(
                '', '%.1f' % seconds).replace(
                    '', motion_predict_head).replace(
                        '',
                        motion_predict_last).replace(
                            '', motion_masked)

        return prompt

    def template_fulfill(self,
                         tasks,
                         lengths,
                         motion_strings,
                         texts,
                         stage='test'):
        """Pick a random input/output template per sample and fill it in.

        Args:
            tasks: ``tasks[i]['input']`` / ``tasks[i]['output']`` are lists of
                template strings for sample ``i``.
            lengths: Motion length per sample.
            motion_strings: Motion string per sample.
            texts: Caption per sample.
            stage: Unused.

        Returns:
            ``(inputs, outputs)`` lists of filled-in prompts.
        """
        inputs = []
        outputs = []
        for i, length in enumerate(lengths):
            input_template = random.choice(tasks[i]['input'])
            output_template = random.choice(tasks[i]['output'])
            inputs.append(
                self.placeholder_fulfill(input_template, length,
                                         motion_strings[i], texts[i]))
            outputs.append(
                self.placeholder_fulfill(output_template, length,
                                         motion_strings[i], texts[i]))

        return inputs, outputs

    def get_middle_str(self, content, startStr, endStr):
        """Extract the text between ``startStr`` and the next ``endStr``.

        The extracted text is re-wrapped with the literal prefix/suffix in
        the return statement. If either marker is missing, the fallback
        literal is returned instead.
        """
        try:
            startIndex = content.index(startStr) + len(startStr)
            # Search for the end marker only after the start marker.
            endIndex = content.index(endStr, startIndex)
        except ValueError:
            return f''
        return f'' + content[
            startIndex:endIndex] + f''

    def random_spans_noise_mask(self, length):
        """Return a boolean noise mask of ``length`` for T5 span corruption.

        From https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py
        Roughly ``noise_density * length`` positions are True, grouped into
        spans whose mean length is ``mean_noise_span_length``.
        """
        orig_length = length

        num_noise_tokens = int(np.round(length * self.noise_density))
        # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
        num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
        num_noise_spans = int(
            np.round(num_noise_tokens / self.mean_noise_span_length))

        # avoid degeneracy by ensuring positive number of noise spans
        num_noise_spans = max(num_noise_spans, 1)
        num_nonnoise_tokens = length - num_noise_tokens

        # pick the lengths of the noise spans and the non-noise spans
        def _random_segmentation(num_items, num_segments):
            """Partition a sequence of items randomly into non-empty segments.

            Args:
                num_items: an integer scalar > 0
                num_segments: an integer scalar in [1, num_items]
            Returns:
                an array with shape [num_segments] containing positive
                integers that add up to num_items
            """
            mask_indices = np.arange(num_items - 1) < (num_segments - 1)
            np.random.shuffle(mask_indices)
            first_in_segment = np.pad(mask_indices, [[1, 0]])
            segment_id = np.cumsum(first_in_segment)
            # count length of sub segments assuming that list is sorted
            _, segment_length = np.unique(segment_id, return_counts=True)
            return segment_length

        noise_span_lengths = _random_segmentation(num_noise_tokens,
                                                  num_noise_spans)
        nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens,
                                                     num_noise_spans)

        interleaved_span_lengths = np.reshape(
            np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
            [num_noise_spans * 2],
        )
        span_starts = np.cumsum(interleaved_span_lengths)[:-1]
        span_start_indicator = np.zeros((length, ), dtype=np.int8)
        span_start_indicator[span_starts] = True
        span_num = np.cumsum(span_start_indicator)
        # Spans alternate non-noise / noise, starting with non-noise.
        is_noise = np.equal(span_num % 2, 1)

        return is_noise[:orig_length]

    def create_sentinel_ids(self, mask_indices):
        """Map a (batch, seq) int8 mask to sentinel ids.

        Adapted from https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py
        The first position of each masked span gets a sentinel id (counting
        down from ``<extra_id_0>``), the remaining masked positions get -1,
        and unmasked positions get 0.
        """
        start_indices = mask_indices - np.roll(mask_indices, 1,
                                               axis=-1) * mask_indices
        start_indices[:, 0] = mask_indices[:, 0]

        # The HF reference uses len(tokenizer), which assumes the
        # <extra_id_*> sentinels are the last vocabulary entries. Tokens
        # added via add_tokens in __init__ come after them, so anchor on
        # <extra_id_0> when the tokenizer defines it.
        sentinel_base = len(self.tokenizer)
        extra_id_0 = self.tokenizer.convert_tokens_to_ids('<extra_id_0>')
        if extra_id_0 is not None and extra_id_0 != self.tokenizer.unk_token_id:
            sentinel_base = extra_id_0 + 1

        sentinel_ids = np.where(start_indices != 0,
                                np.cumsum(start_indices, axis=-1),
                                start_indices)
        sentinel_ids = np.where(sentinel_ids != 0,
                                (sentinel_base - sentinel_ids), 0)
        sentinel_ids -= mask_indices - start_indices
        return sentinel_ids

    def filter_input_ids(self, input_ids, sentinel_ids):
        """Apply sentinels to ``input_ids``, drop masked positions and append EOS.

        Adapted from https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py

        Returns:
            A tensor on the same device as ``input_ids``.
        """
        device = input_ids.device
        batch_size = input_ids.shape[0]

        input_ids_full = np.where(sentinel_ids != 0, sentinel_ids,
                                  input_ids.cpu().numpy())
        # input_ids tokens and sentinel tokens are >= 0, tokens < 0 are
        # masked tokens coming after sentinel tokens and should be removed
        filtered = input_ids_full[input_ids_full >= 0].reshape(
            (batch_size, -1))
        filtered = np.concatenate(
            [
                filtered,
                np.full((batch_size, 1),
                        self.tokenizer.eos_token_id,
                        dtype=np.int32),
            ],
            axis=-1,
        )
        return torch.tensor(filtered, device=device)