from multilingual_clip import Config_MCLIP
import tensorflow as tf
import transformers


class SentenceModel(tf.keras.Model):
    """Wrap a Hugging Face TF transformer and mean-pool its token states.

    Each sample's embedding is the average of the transformer's
    ``last_hidden_state`` over the positions where ``attention_mask`` is
    non-zero.
    """

    def __init__(self, modelBase, from_pt=True, *args, **kwargs):
        """Load the transformer backbone.

        Args:
            modelBase: Model name or path passed to
                ``transformers.TFAutoModel.from_pretrained``.
            from_pt: Forwarded to ``from_pretrained`` as ``from_pt``.
            *args, **kwargs: Forwarded to ``tf.keras.Model.__init__``.
        """
        super().__init__(*args, **kwargs)
        # NOTE(review): attribute names likely take part in saved-weight
        # naming; keep them stable so existing checkpoints still load.
        self.transformer = transformers.TFAutoModel.from_pretrained(
            modelBase, from_pt=from_pt
        )

    @tf.function
    def generateMeanPooledSentenceEmbs(self, input, training=False):
        """Return the attention-mask-weighted mean of the last hidden state.

        Args:
            input: Mapping passed unchanged to the transformer; it must
                contain an ``'attention_mask'`` entry. (The name shadows the
                builtin but is kept for keyword-argument compatibility.)
            training: Forwarded to the transformer call.

        Returns:
            Hidden states summed over axis 1 (masked positions zeroed),
            divided by the number of unmasked positions per sample. A sample
            whose mask is entirely zero yields zeros instead of NaN.
        """
        output = self.transformer(input, training=training)
        hiddenStates = output['last_hidden_state']
        # Assumes mask is (batch, seq_len) and hidden states are
        # (batch, seq_len, hidden) -- TODO confirm against the backbone.
        outAtt = tf.cast(input['attention_mask'], tf.float32)
        sampleLength = tf.reduce_sum(outAtt, axis=-1, keepdims=True)
        maskedEmbs = hiddenStates * tf.expand_dims(outAtt, axis=-1)
        summedEmbs = tf.reduce_sum(maskedEmbs, axis=1)
        # divide_no_nan matches plain division whenever sampleLength > 0,
        # but returns 0 (not NaN) for fully padded samples.
        return tf.math.divide_no_nan(summedEmbs, sampleLength)

    @tf.function
    def call(self, inputs, training=False, mask=None):
        """Return mean-pooled embeddings; ``mask`` is accepted but unused."""
        return self.generateMeanPooledSentenceEmbs(inputs, training)


class SentenceModelWithLinearTransformation(SentenceModel):
    """Mean-pooled sentence model followed by a linear Dense projection."""

    def __init__(self, modelBase, embeddingSize=640, *args, **kwargs):
        """Build the backbone and the projection layer.

        Args:
            modelBase: Model name or path for the transformer backbone.
            embeddingSize: Output width of the Dense projection.
            *args, **kwargs: Forwarded to ``SentenceModel.__init__``.
        """
        super().__init__(modelBase, *args, **kwargs)
        # Layer name is kept fixed; it likely determines saved weight names.
        self.postTransformation = tf.keras.layers.Dense(
            embeddingSize, activation='linear', name='LinearTransformation'
        )

    @tf.function
    def call(self, inputs, training=False, mask=None):
        """Return projected mean-pooled embeddings; ``mask`` is unused."""
        pooled = self.generateMeanPooledSentenceEmbs(inputs, training)
        return self.postTransformation(pooled)


class MultiLingualCLIP(transformers.TFPreTrainedModel):
    """Transformers-compatible wrapper around the projected sentence model.

    Reads ``config.modelBase`` (backbone name/path) and ``config.numDims``
    (projection width) from an ``MCLIPConfig``.
    """

    config_class = Config_MCLIP.MCLIPConfig

    def __init__(self, config, *args, **kwargs):
        """Create the wrapped sentence model from ``config``."""
        super().__init__(config, *args, **kwargs)
        self.sentenceModel = SentenceModelWithLinearTransformation(
            config.modelBase, config.numDims
        )

    @property
    def dummy_inputs(self):
        """Return int32 all-ones ``input_ids``/``attention_mask`` of shape (4, 12).

        Presumably used by ``transformers`` to build the model's weights.
        """
        return {
            'input_ids': tf.ones((4, 12), tf.int32),
            'attention_mask': tf.ones((4, 12), tf.int32),
        }

    @tf.function(
        input_signature=[
            tf.TensorSpec((None, None), tf.int32),
            tf.TensorSpec((None, None), tf.int32),
        ]
    )
    def serving(self, ids, att):
        """Serving entry point taking token ids and attention mask tensors.

        Args:
            ids: int32 token ids, shape (None, None).
            att: int32 attention mask, shape (None, None).
        """
        # The pooling code reads inputs['attention_mask'], so the tensors must
        # be passed as a dict; the previous tuple raised TypeError on tracing.
        output = self.call({'input_ids': ids, 'attention_mask': att})
        return self.serving_output(output)

    def serving_output(self, outputs):
        """Return ``outputs`` unchanged."""
        return outputs

    @tf.function
    def call(self, inputs, training=False, mask=None):
        """Delegate to the wrapped sentence model's ``call``; ``mask`` is unused."""
        return self.sentenceModel.call(inputs, training)