vamsibanda
commited on
Commit
•
eb644f5
1
Parent(s):
10d6ae9
Update README.md
Browse files
README.md
CHANGED
@@ -37,6 +37,7 @@ import torch
|
|
37 |
from transformers.modeling_outputs import BaseModelOutput
|
38 |
from transformers import T5TokenizerFast
|
39 |
import torch.nn.functional as F
|
|
|
40 |
|
41 |
model_name = 'vamsibanda/sbert-onnx-all-roberta-large-v1'
|
42 |
cache_folder = './'
|
@@ -49,7 +50,7 @@ def mean_pooling(model_output, attention_mask):
|
|
49 |
|
50 |
def download_onnx_model(model_name, cache_folder, model_path, force_download = False):
|
51 |
if force_download and os.path.exists(model_path):
|
52 |
-
|
53 |
|
54 |
snapshot_download(model_name,
|
55 |
cache_dir=cache_folder,
|
|
|
37 |
from transformers.modeling_outputs import BaseModelOutput
|
38 |
from transformers import T5TokenizerFast
|
39 |
import torch.nn.functional as F
|
40 |
+
import shutil
|
41 |
|
42 |
model_name = 'vamsibanda/sbert-onnx-all-roberta-large-v1'
|
43 |
cache_folder = './'
|
|
|
50 |
|
51 |
def download_onnx_model(model_name, cache_folder, model_path, force_download = False):
|
52 |
if force_download and os.path.exists(model_path):
|
53 |
+
shutil.rmtree(model_path)
|
54 |
|
55 |
snapshot_download(model_name,
|
56 |
cache_dir=cache_folder,
|