Spaces:
Runtime error
Runtime error
File size: 1,279 Bytes
bfc0ec6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
"""Utils for transformer embeddings."""
import functools
import os
from typing import TYPE_CHECKING, Optional
from ..env import data_path
from ..utils import log
if TYPE_CHECKING:
from sentence_transformers import SentenceTransformer
@functools.cache
def _load_model(model_name: str, device: Optional[str],
                cache_folder: str) -> 'SentenceTransformer':
    """Build a SentenceTransformer, memoized per (model_name, device, cache_folder)."""
    # Imported lazily so this module can be imported without the optional
    # `sentence_transformers` dependency installed.
    from sentence_transformers import SentenceTransformer
    return SentenceTransformer(model_name, device=device, cache_folder=cache_folder)


def get_model(model_name: str,
              optimal_batch_sizes: dict[str, int] = {}) -> tuple[int, 'SentenceTransformer']:
    """Get a transformer model and the optimal batch size for it.

    The device is chosen at call time: 'mps' when `torch.backends.mps` reports it
    as available, otherwise no explicit device is passed to SentenceTransformer.

    Args:
      model_name: Name passed through to the `SentenceTransformer` constructor.
      optimal_batch_sizes: Mapping from device name to batch size. It is looked up
        with 'mps' when MPS is selected and with '' (empty string) otherwise. The
        mapping is only read, never mutated, so the shared default is harmless;
        note that the empty default has no '' key, so callers are expected to
        pass a mapping.

    Returns:
      A `(batch_size, model)` tuple. The model instance is cached, so repeated
      calls with the same model name, device and cache folder reuse one instance.

    Raises:
      ImportError: If `torch.backends.mps` or `sentence_transformers` cannot be
        imported.
      KeyError: If `optimal_batch_sizes` has no entry for the selected device.
    """
    try:
        import torch.backends.mps
        from sentence_transformers import SentenceTransformer  # noqa: F401 (availability check)
    except ImportError as err:
        raise ImportError('Could not import the "sentence_transformers" python package. '
                          'Please install it with `pip install sentence-transformers`.') from err

    preferred_device: Optional[str] = None
    if torch.backends.mps.is_available():
        preferred_device = 'mps'
    elif not torch.backends.mps.is_built():
        log('MPS not available because the current PyTorch install was not built with MPS enabled.')

    # Resolve the batch size first so a missing device key fails before any
    # (potentially expensive) model construction.
    batch_size = optimal_batch_sizes[preferred_device or '']

    # The loader is cached at module level. Decorating a closure defined inside
    # this function would create a brand-new, empty cache on every call and the
    # model would be rebuilt each time.
    cache_folder = os.path.join(data_path(), '.cache')
    return batch_size, _load_model(model_name, preferred_device, cache_folder)
|