sberhe committed on
Commit
9363e7a
·
1 Parent(s): 680cd70

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -30,7 +30,7 @@ def extract_embeddings(batch):
30
  inputs = {k: tf.convert_to_tensor(v) for k, v in batch.items() if k in tokenizer.model_input_names}
31
  outputs = model(**inputs, output_hidden_states=True, return_dict=True)
32
  embeddings = outputs.last_hidden_state
33
- return {"embeddings": embeddings}
34
 
35
  # Apply the function to extract embeddings in batches
36
  embeddings_dataset = tokenized_datasets.map(extract_embeddings, batched=True, batch_size=batch_size)
 
30
  inputs = {k: tf.convert_to_tensor(v) for k, v in batch.items() if k in tokenizer.model_input_names}
31
  outputs = model(**inputs, output_hidden_states=True, return_dict=True)
32
  embeddings = outputs.last_hidden_state
33
+ return {"embeddings": embeddings.numpy()} # Ensure the "embeddings" key is present
34
 
35
  # Apply the function to extract embeddings in batches
36
  embeddings_dataset = tokenized_datasets.map(extract_embeddings, batched=True, batch_size=batch_size)