Spaces:
Running
Running
Update sentence_embeddings.py
Browse files- sentence_embeddings.py +3 -2
sentence_embeddings.py
CHANGED
@@ -7,6 +7,7 @@ from datetime import datetime
|
|
7 |
from logger import log
|
8 |
from config import TEST_MODE
|
9 |
|
|
|
10 |
router = APIRouter()
|
11 |
|
12 |
class SentenceEmbeddingsInput(BaseModel):
|
@@ -59,11 +60,11 @@ def generic_sentence_embeddings(model_name: str):
|
|
59 |
tokenizer, model = loaded_models[model_name]
|
60 |
else:
|
61 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
62 |
-
model = AutoModel.from_pretrained(model_name)
|
63 |
loaded_models[model] = (tokenizer, model)
|
64 |
|
65 |
# Tokenize sentences
|
66 |
-
encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
|
67 |
with torch.no_grad():
|
68 |
model_output = model(**encoded_input)
|
69 |
sentence_embeddings = model_output[0][:, 0]
|
|
|
7 |
from logger import log
|
8 |
from config import TEST_MODE
|
9 |
|
10 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
11 |
router = APIRouter()
|
12 |
|
13 |
class SentenceEmbeddingsInput(BaseModel):
|
|
|
60 |
tokenizer, model = loaded_models[model_name]
|
61 |
else:
|
62 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
63 |
+
model = AutoModel.from_pretrained(model_name).to(device)
|
64 |
loaded_models[model] = (tokenizer, model)
|
65 |
|
66 |
# Tokenize sentences
|
67 |
+
encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt').to(device)
|
68 |
with torch.no_grad():
|
69 |
model_output = model(**encoded_input)
|
70 |
sentence_embeddings = model_output[0][:, 0]
|