Replace deprecated `t5x.partitioning.ModelBasedPjitPartitioner` with `t5x.partitioning.PjitPartitioner`

#2
by deleted - opened
Files changed (1) hide show
  1. app.py +1 -2
app.py CHANGED
@@ -85,7 +85,7 @@ class InferenceModel(object):
85
  self.sequence_length = {'inputs': self.inputs_length,
86
  'targets': self.outputs_length}
87
 
88
- self.partitioner = t5x.partitioning.ModelBasedPjitPartitioner(
89
  model_parallel_submesh=(1, 1, 1, 1), num_partitions=1)
90
 
91
  # Build Codecs and Vocabularies.
@@ -178,7 +178,6 @@ class InferenceModel(object):
178
 
179
  Args:
180
  audio: 1-d numpy array of audio samples (16kHz) for a single example.
181
-
182
  Returns:
183
  A note_sequence of the transcribed audio.
184
  """
 
85
  self.sequence_length = {'inputs': self.inputs_length,
86
  'targets': self.outputs_length}
87
 
88
+ self.partitioner = t5x.partitioning.PjitPartitioner(
89
  model_parallel_submesh=(1, 1, 1, 1), num_partitions=1)
90
 
91
  # Build Codecs and Vocabularies.
 
178
 
179
  Args:
180
  audio: 1-d numpy array of audio samples (16kHz) for a single example.
 
181
  Returns:
182
  A note_sequence of the transcribed audio.
183
  """