jxtan commited on
Commit
355fbaf
1 Parent(s): 12512d7

Update sentence_embeddings.py

Browse files
Files changed (1) hide show
  1. sentence_embeddings.py +3 -2
sentence_embeddings.py CHANGED
@@ -7,6 +7,7 @@ from datetime import datetime
7
  from logger import log
8
  from config import TEST_MODE
9
 
 
10
  router = APIRouter()
11
 
12
  class SentenceEmbeddingsInput(BaseModel):
@@ -59,11 +60,11 @@ def generic_sentence_embeddings(model_name: str):
59
  tokenizer, model = loaded_models[model_name]
60
  else:
61
  tokenizer = AutoTokenizer.from_pretrained(model_name)
62
- model = AutoModel.from_pretrained(model_name)
63
  loaded_models[model] = (tokenizer, model)
64
 
65
  # Tokenize sentences
66
- encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
67
  with torch.no_grad():
68
  model_output = model(**encoded_input)
69
  sentence_embeddings = model_output[0][:, 0]
 
7
  from logger import log
8
  from config import TEST_MODE
9
 
10
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
11
  router = APIRouter()
12
 
13
  class SentenceEmbeddingsInput(BaseModel):
 
60
  tokenizer, model = loaded_models[model_name]
61
  else:
62
  tokenizer = AutoTokenizer.from_pretrained(model_name)
63
+ model = AutoModel.from_pretrained(model_name).to(device)
64
  loaded_models[model] = (tokenizer, model)
65
 
66
  # Tokenize sentences
67
+ encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt').to(device)
68
  with torch.no_grad():
69
  model_output = model(**encoded_input)
70
  sentence_embeddings = model_output[0][:, 0]