import tensorflow as tf
from tensorflow.keras import Model
from transformers import TFPreTrainedModel

# GPT-2 checkpoints accepted as transformer backbones; the first entry is the
# default when the config does not specify one.
valid_types = ["gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"]


class Transformer(Model):
    """GPT-2 based sequence model that predicts frame token indices,
    conditioned on the number of remaining frames."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # How the remaining-frames signal is fed to the transformer:
        # "concat", "token_type_ids", or "own_embeddings".
        self.remaining_frames_method = self.get_remaining_frames_method(config)
        self.transformer_type = self.get_transformer_type(config)
        self.transformer = self.load_transformer(
            self.remaining_frames_method, self.transformer_type
        )

    def get_transformer_type(self, config) -> str:
        """Read the GPT-2 variant from the config, falling back to the default."""
        if "transformer_type" in config:
            transformer_type = config["transformer_type"]
            if transformer_type not in valid_types:
                raise ValueError(
                    f"transformer_type {transformer_type} is not valid. Valid types are {valid_types}"
                )
            return transformer_type
        else:
            return valid_types[0]

    def get_remaining_frames_method(self, config) -> str:
        """Get the method used to inject the remaining-frames signal.

        Returns the value set in the configuration, or "concat" as the default.
        """
        if "remaining_frames_method" in config:
            return config["remaining_frames_method"]
        else:
            return "concat"

    def load_transformer(self, method: str, transformer_type: str) -> TFPreTrainedModel:
        """Load the GPT-2 language model matching the chosen method."""
        print(f"using method {method}")
        if method == "own_embeddings":
            # Custom GPT-2 variant with a dedicated embedding for the
            # remaining-frames ids.
            from ganime.model.vqgan_clean.experimental.gpt2_embedding import (
                TFGPT2LMHeadModel,
            )
        else:
            # Stock Hugging Face GPT-2.
            from transformers import TFGPT2LMHeadModel

        return TFGPT2LMHeadModel.from_pretrained(transformer_type)

    def concatenate_inputs(
        self, remaining_frames, last_frame_indices, previous_frame_indices
    ) -> tf.Tensor:
        """Build the input token sequence; prepend the remaining-frames token
        only for the "concat" method (the other methods pass it separately)."""
        if self.remaining_frames_method == "concat":
            return tf.concat(
                [remaining_frames, last_frame_indices, previous_frame_indices], axis=1
            )
        else:
            return tf.concat([last_frame_indices, previous_frame_indices], axis=1)

    def call_transformer(
        self, transformer_input, remaining_frames, training, attention_mask
    ):
        if self.remaining_frames_method == "concat":
            # Remaining frames are already part of transformer_input.
            return self.transformer(
                transformer_input, training=training, attention_mask=attention_mask
            )
        elif self.remaining_frames_method == "token_type_ids":
            # Reuse GPT-2's token type embeddings to encode remaining frames.
            return self.transformer(
                transformer_input,
                token_type_ids=remaining_frames,
                training=training,
                attention_mask=attention_mask,
            )
        elif self.remaining_frames_method == "own_embeddings":
            # Custom GPT-2 variant with a dedicated remaining_frames_ids input.
            return self.transformer(
                transformer_input,
                remaining_frames_ids=remaining_frames,
                training=training,
                attention_mask=attention_mask,
            )
        else:
            raise ValueError(
                f"Unknown remaining_frames_method {self.remaining_frames_method}"
            )

    def call(self, inputs, training=True, mask=None):
        remaining_frames, last_frame_indices, previous_frame_indices = inputs
        # (batch,) -> (batch, 1) so it can be concatenated with the indices.
        remaining_frames = tf.expand_dims(remaining_frames, axis=1)
        # Number of token positions to return at the end (one frame's worth).
        shape_to_keep = tf.shape(last_frame_indices)[1]

        h = self.concatenate_inputs(
            remaining_frames, last_frame_indices, previous_frame_indices
        )

        transformer_input = h
        # Attention mask: all ones per sequence, zeroed out entirely for
        # sequences whose remaining_frames count is 0 (the (batch, 1) factor
        # broadcasts over the sequence dimension).
        mask = tf.ones_like(transformer_input) * tf.cast(
            tf.cast(remaining_frames, dtype=tf.bool), dtype=remaining_frames.dtype
        )

        h = self.call_transformer(transformer_input, remaining_frames, training, mask)
        h = h.logits

        # Keep only the logits for the last shape_to_keep positions.
        h = h[:, -shape_to_keep:]
        return h
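

# Usage sketch (not part of the original module): a minimal, hedged example of
# how the model might be invoked. The config keys match those read above; the
# batch size, tokens-per-frame, and dummy zero token ids are assumptions made
# purely for illustration. Running this downloads the pretrained "gpt2"
# weights from Hugging Face.
if __name__ == "__main__":
    config = {
        "transformer_type": "gpt2",
        "remaining_frames_method": "concat",
    }
    model = Transformer(config)

    batch, tokens_per_frame = 2, 16  # assumed sizes for the sketch
    remaining_frames = tf.constant([5, 3])  # one count per sequence, shape (batch,)
    last_frame_indices = tf.zeros((batch, tokens_per_frame), dtype=tf.int32)
    previous_frame_indices = tf.zeros((batch, tokens_per_frame), dtype=tf.int32)

    logits = model(
        (remaining_frames, last_frame_indices, previous_frame_indices),
        training=False,
    )
    print(logits.shape)  # expected: (batch, tokens_per_frame, vocab_size)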