Flax version

#1 opened by iohadrubin
from typing import Optional

import flax.linen as nn
import jax.numpy as jnp
from jax.random import PRNGKey

from transformers import AutoTokenizer, T5Config, FlaxT5EncoderModel
from transformers.models.t5.modeling_flax_t5 import FlaxT5EncoderModule


class ModuleT5EncoderWithProjection(nn.Module):
    config: T5Config
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False

    def setup(self):
        self.t5_encoder = FlaxT5EncoderModule(
            self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        self.projection = nn.Dense(self.config.d_model, use_bias=False, dtype=self.dtype)

    def __call__(
        self,
        input_ids=None,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict: bool = False,
        deterministic: bool = True,
    ):
        # Encode, pool the first token's hidden state, and project it back to d_model.
        last_hidden_state = self.t5_encoder(
            input_ids, attention_mask=attention_mask, deterministic=deterministic
        )[0]
        first_token_state = last_hidden_state[:, 0, :]
        return self.projection(first_token_state)

class T5EncoderWithProjection(FlaxT5EncoderModel):
    # FlaxT5EncoderModel.__init__ already accepts config, dtype and _do_init,
    # so overriding the module class is all that's needed here.
    module_class: nn.Module = ModuleT5EncoderWithProjection

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: Optional[dict] = None,
        dropout_rng: Optional[PRNGKey] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # prepare encoder inputs: attend to every position by default
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # handle any PRNG if needed (dropout is only active when train=True)
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )
encoder = T5EncoderWithProjection.from_pretrained("kalpeshk2011/rankgen-t5-base-all", from_pt=True)
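
For reference, here's an end-to-end call. The tokenizer name ("t5-base") and the "pre " marker RankGen prepends to prefix texts are my assumptions from the RankGen repo, not something this snippet enforces, so double-check them against the model card:

# Usage sketch: the tokenizer name and the "pre " input marker below are
# assumptions, not taken from the checkpoint itself.
tokenizer = AutoTokenizer.from_pretrained("t5-base")
inputs = tokenizer(
    ["pre Rome is the capital of Italy."],
    padding=True,
    return_tensors="np",
)
embeddings = encoder(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
)
print(embeddings.shape)  # (batch_size, d_model): (1, 768) for the base model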
