Spaces:
Running
Running
fix: should be converted to array
Browse files
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -490,7 +490,7 @@ def main():
|
|
490 |
jnp.array(labels), config.pad_token_id, config.decoder_start_token_id
|
491 |
)
|
492 |
|
493 |
-
model_inputs["decoder_input_ids"] = decoder_input_ids
|
494 |
|
495 |
return model_inputs
|
496 |
|
|
|
490 |
jnp.array(labels), config.pad_token_id, config.decoder_start_token_id
|
491 |
)
|
492 |
|
493 |
+
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
|
494 |
|
495 |
return model_inputs
|
496 |
|