def list_uniq(l: list) -> list:
    # Deduplicate while preserving first-seen order, e.g. list_uniq([3, 1, 3, 2]) -> [3, 1, 2].
    return sorted(set(l), key=l.index)
def get_status(model_name: str):
    from huggingface_hub import InferenceClient
    client = InferenceClient(timeout=10)
    return client.get_model_status(model_name)
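
# get_model_status() returns huggingface_hub's ModelStatus dataclass; the fields used
# below are .state (e.g. "Loadable", "Loaded", "TooBig") and .compute_type, which may
# be a dict keyed by hardware ("gpu", ...) or a plain string depending on the hub
# version, hence the isinstance() guard in is_loadable().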
def is_loadable(model_name: str, force_gpu: bool = False) -> bool:
    try:
        status = get_status(model_name)
    except Exception as e:
        print(e)
        print(f"Couldn't load {model_name}.")
        return False
    # Check state before touching other attributes, so a missing status can't raise.
    if status is None or status.state not in ["Loadable", "Loaded"]:
        print(f"Couldn't load {model_name}. Model state: '{status.state if status else None}'")
        return False
    # compute_type is a dict keyed by hardware kind (e.g. "gpu") when serving on GPU.
    gpu_state = isinstance(status.compute_type, dict) and "gpu" in status.compute_type
    if force_gpu and not gpu_state:
        print(f"Couldn't load {model_name}. Model state: '{status.state}', GPU: {gpu_state}")
        return False
    return True
def find_model_list(author: str = "", tags: list[str] | None = None, not_tag: str = "",
                    sort: str = "last_modified", limit: int = 30, force_gpu: bool = True) -> list[str]:
    from huggingface_hub import HfApi
    api = HfApi()
    #default_tags = ["transformers"]
    default_tags = []
    tags = tags if tags else []  # None default avoids the mutable-default-argument pitfall
    if not sort: sort = "last_modified"
    models = []
    # Over-fetch: most candidates are filtered out below, and the GPU check rejects far more.
    fetch_limit = limit * 20 if force_gpu else limit * 5
    try:
        model_infos = api.list_models(author=author, pipeline_tag="text-generation",
                                      tags=list_uniq(default_tags + tags), cardData=True,
                                      sort=sort, limit=fetch_limit)
    except Exception as e:
        print("Error: Failed to list models.")
        print(e)
        return models
    for model in model_infos:
        if not model.private and not model.gated:
            # Skip models carrying the excluded tag or that the Inference API can't serve.
            if (not_tag and not_tag in model.tags) or not is_loadable(model.id, force_gpu): continue
            models.append(model.id)
            # Stop at the requested count, not the inflated fetch_limit.
            if len(models) == limit: break
    return models
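
# Minimal usage sketch (assumes network access to the Hugging Face Hub; the author
# name below is illustrative, not from the original file):
if __name__ == "__main__":
    # Up to 5 recently modified, ungated text-generation models by one author that the
    # Serverless Inference API reports as loadable on GPU. Each candidate costs one
    # status call, so this can take a while.
    print(find_model_list(author="Qwen", limit=5, force_gpu=True))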