import torch
import numpy as np

from transformers import BertForMaskedLM, BertTokenizer

# Load the Zari BERT (counterfactual data augmentation) checkpoint
# from a local directory.
modelpath = 'zari-bert-cda/'
tokenizer = BertTokenizer.from_pretrained(modelpath)
model = BertForMaskedLM.from_pretrained(modelpath)
model.eval()

# Vocabulary id of the [MASK] token (103 in the standard BERT vocab).
id_of_mask = tokenizer.mask_token_id

# Return the masked-LM logits at the [MASK] position of `sentence` as a list.
# The sentence must contain a [MASK] token; the first occurrence is used.
def get_embeddings(sentence):
  with torch.no_grad():
    # encode() adds the [CLS] and [SEP] special tokens automatically,
    # so the sentence needs no manual wrapping.
    tokenized = tokenizer.encode(sentence)
    input_ids = torch.tensor(tokenized).unsqueeze(0)  # Batch size 1
    outputs = model(input_ids)
    index_of_mask = tokenized.index(id_of_mask)

    # Shape: [batch, tokens, vocab_size]
    prediction_scores = outputs[0]

    return prediction_scores[0][index_of_mask].cpu().numpy().tolist()
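
# Usage sketch (hedged: the example sentence is illustrative, not from the
# original file). get_embeddings returns one logit per vocabulary id, so the
# argmax decodes to the model's top prediction for the blank.
example_scores = get_embeddings('The nurse said [MASK] would be back soon.')
example_top = tokenizer.convert_ids_to_tokens([int(np.argmax(example_scores))])[0]
print('Top prediction for [MASK]:', example_top)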



import os
import shutil

# Free up disk space: delete the local model weights when requested.
if os.environ.get('REMOVE_WEIGHTS') == 'TRUE':
  print('removing zari-bert-cda from filesystem')
  shutil.rmtree('zari-bert-cda', ignore_errors=True)
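
# Invocation sketch (hypothetical file name): enable the cleanup above with
#   REMOVE_WEIGHTS=TRUE python get_embeddings.py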