vamsibanda committed
Commit 2c12b75
Parent: c07c56a

Update README.md

Files changed (1): README.md (+12, -12)
README.md CHANGED

@@ -41,18 +41,6 @@ model_name = 'vamsibanda/sbert-onnx-all-roberta-large-v1'
 cache_folder = './'
 model_path = os.path.join(cache_folder, model_name.replace("/", "_"))
 
-def generate_embedding(text):
-    token = tokenizer(text, return_tensors='pt')
-    embeddings = model(input_ids=token['input_ids'], attention_mask=token['attention_mask'])
-    sbert_embeddings = mean_pooling(embeddings, token['attention_mask'])
-    sbert_embeddings = F.normalize(sbert_embeddings, p=2, dim=1)
-    return sbert_embeddings.tolist()[0]
-
-def mean_pooling(model_output, attention_mask):
-    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
-    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
-    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
-
 def download_onnx_model(model_name, cache_folder, model_path, force_download = False):
     if force_download and os.path.exists(model_path):
         shutil.rmtree(model_path)
@@ -63,6 +51,18 @@ def download_onnx_model(model_name, cache_folder, model_path, force_download = False):
         library_name='sentence-transformers'
     )
     return
+
+def mean_pooling(model_output, attention_mask):
+    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+def generate_embedding(text):
+    token = tokenizer(text, return_tensors='pt')
+    embedding = model(input_ids=token['input_ids'], attention_mask=token['attention_mask'])
+    embedding = mean_pooling(embedding, token['attention_mask'])
+    embedding = F.normalize(embedding, p=2, dim=1)
+    return embedding.tolist()[0]
 
 
 _ = download_onnx_model(model_name, cache_folder, model_path)
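The diffed code leans on names set up elsewhere in the README (`torch`, `F`, `tokenizer`, `model`, and the body of `download_onnx_model`). For context, below is a minimal, self-contained sketch of how the new `generate_embedding()` could be wired up; loading the ONNX weights through `AutoTokenizer` and optimum's `ORTModelForFeatureExtraction` is an assumption for illustration, not necessarily the loading code the model card itself uses.

```python
# Sketch only: AutoTokenizer and ORTModelForFeatureExtraction are assumed here;
# the model card's full README carries the actual loading code.
import torch
import torch.nn.functional as F
from optimum.onnxruntime import ORTModelForFeatureExtraction
from transformers import AutoTokenizer

model_name = 'vamsibanda/sbert-onnx-all-roberta-large-v1'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = ORTModelForFeatureExtraction.from_pretrained(model_name)

def mean_pooling(model_output, attention_mask):
    # Average the token embeddings, masking out padding positions.
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

def generate_embedding(text):
    # Tokenize -> ONNX forward pass -> mean-pool -> L2-normalize, as in the diff.
    token = tokenizer(text, return_tensors='pt')
    embedding = model(input_ids=token['input_ids'], attention_mask=token['attention_mask'])
    embedding = mean_pooling(embedding, token['attention_mask'])
    embedding = F.normalize(embedding, p=2, dim=1)
    return embedding.tolist()[0]

print(len(generate_embedding('How big is London?')))  # 1024 dims for a roberta-large backbone
```

Note that the commit only reorders the two helpers so `mean_pooling` is defined before its caller when the snippet is read top to bottom; since Python resolves the name at call time, either order runs the same.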