rianders committed on
Commit
95c80a2
1 Parent(s): 24297f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -11,19 +11,20 @@ def get_bert_embeddings(words):
11
  model = BertModel.from_pretrained('bert-base-uncased')
12
  embeddings = []
13
 
14
- # Extract embeddings
15
  for word in words:
16
  inputs = tokenizer(word, return_tensors='pt')
17
  outputs = model(**inputs)
18
- embeddings.append(outputs.last_hidden_state[0][0].detach().numpy())
 
 
19
 
20
- # Reduce dimensions to 3 using PCA
21
  if len(embeddings) > 0:
22
  pca = PCA(n_components=3)
23
  reduced_embeddings = pca.fit_transform(np.array(embeddings))
24
  return reduced_embeddings
25
  return []
26
 
 
27
  # Plotly plotting function
28
  def plot_interactive_bert_embeddings(embeddings, words):
29
  if len(words) < 4:
 
11
  model = BertModel.from_pretrained('bert-base-uncased')
12
  embeddings = []
13
 
 
14
  for word in words:
15
  inputs = tokenizer(word, return_tensors='pt')
16
  outputs = model(**inputs)
17
+ # Use the [CLS] token's embedding
18
+ cls_embedding = outputs.last_hidden_state[0][0].detach().numpy()
19
+ embeddings.append(cls_embedding)
20
 
 
21
  if len(embeddings) > 0:
22
  pca = PCA(n_components=3)
23
  reduced_embeddings = pca.fit_transform(np.array(embeddings))
24
  return reduced_embeddings
25
  return []
26
 
27
+
28
  # Plotly plotting function
29
  def plot_interactive_bert_embeddings(embeddings, words):
30
  if len(words) < 4: