# Flax version
# (Hub discussion #1, opened by iohadrubin)
from transformers import AutoTokenizer
from transformers import FlaxT5PreTrainedModel, FlaxT5EncoderModel,FlaxT5EncoderModel
from transformers.models.t5.modeling_flax_t5 import FlaxT5EncoderModule
from transformers import T5Config
import flax.linen as nn
from typing import Callable, Optional, Tuple
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from jax.random import PRNGKey
class ModuleT5EncoderWithProjection(nn.Module):
    """T5 encoder stack followed by a bias-free linear projection.

    Runs the Flax T5 encoder, pools the hidden state of the first token
    (CLS-style), and projects it with a ``Dense(d_model)`` layer.
    """

    config: T5Config
    dtype: jnp.dtype = jnp.float32  # dtype of the computation
    gradient_checkpointing: bool = False

    def setup(self):
        # Propagate dtype and gradient checkpointing to the submodules so the
        # encoder actually runs in the configured precision (the original
        # declared these fields but never used them).
        self.t5_encoder = FlaxT5EncoderModule(
            self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.projection = nn.Dense(self.config.d_model, use_bias=False, dtype=self.dtype)

    def __call__(
        self,
        input_ids=None,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict: bool = False,
        deterministic: bool = True,
        rngs=None,
    ):
        """Return the projected first-token embedding, shape (batch, d_model).

        `output_attentions` / `output_hidden_states` / `return_dict` / `rngs`
        are accepted for signature compatibility with the base model's
        ``module.apply`` call but do not affect the returned value.
        """
        # Forward `deterministic` so dropout is toggled for train vs. eval;
        # the original accepted the flag but silently ignored it.
        last_hidden_state = self.t5_encoder(
            input_ids,
            attention_mask=attention_mask,
            deterministic=deterministic,
        )[0]
        # First-token pooling: use position 0 as the sequence representation.
        pooled = last_hidden_state[:, 0, :]
        return self.projection(pooled)
class T5EncoderWithProjection(FlaxT5EncoderModel):
    """Flax pretrained-model wrapper exposing ``ModuleT5EncoderWithProjection``
    through the standard ``from_pretrained`` / ``save_pretrained`` machinery."""

    module_class: nn.Module = ModuleT5EncoderWithProjection

    def __init__(self, config, _do_init=True, dtype=jnp.float32, **kwargs):
        # Defaults and **kwargs keep this constructor compatible with the
        # parent's signature and with `from_pretrained`, which passes extra
        # keyword arguments (the original made `_do_init` and `dtype`
        # required positionals, so `T5EncoderWithProjection(config)` failed).
        super().__init__(config, _do_init=_do_init, dtype=dtype, **kwargs)

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """Encode `input_ids` and return the projected pooled embedding.

        Args:
            input_ids: Integer token ids, shape (batch, seq_len).
            attention_mask: Optional mask; defaults to all ones.
            output_attentions / output_hidden_states / return_dict:
                Fall back to the corresponding ``self.config`` values.
            train: When True, dropout is enabled (``deterministic=False``).
            params: Optional parameter pytree overriding ``self.params``.
            dropout_rng: Optional PRNG key used for dropout.

        Returns:
            The module's output: the projected first-token embedding.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Prepare encoder inputs: attend to every position by default.
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # Handle any PRNG if needed (dropout only runs when train=True).
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )
# Load the RankGen checkpoint (PyTorch weights) into the Flax wrapper.
encoder = T5EncoderWithProjection.from_pretrained(
    "kalpeshk2011/rankgen-t5-base-all", from_pt=True
)