Spaces: Running on CPU Upgrade
import asyncio | |
import gradio as gr | |
import src.constants as constants | |
from src.hub import list_models, load_model_card | |
async def load_model_tree(result_paths_per_model, model_ids):
    """Build one Gradio dropdown per model-tree relation of the selected model.

    The first dropdown lists the base models of the selected model. Each
    following dropdown lists the models derived from it with one of the
    relation types in ``constants.DERIVED_MODEL_TYPES``. Only models for which
    ``model in result_paths_per_model`` holds are offered as choices.

    Args:
        result_paths_per_model: Container (tested with ``in``) of the model ids
            that have results available.
        model_ids: Selected model ids; only the first one is used.

    Returns:
        A list of ``gr.Dropdown`` components: the base-model dropdown followed
        by one dropdown per entry of ``constants.DERIVED_MODEL_TYPES``, in that
        order. Each label shows its number of choices, and a dropdown without
        choices is not interactive. If ``model_ids`` is empty/None, every
        dropdown is returned empty.
    """
    model_tree_labels = [constants.BASE_MODEL_TYPE[0]] + [
        derived_model_type[0] for derived_model_type in constants.DERIVED_MODEL_TYPES
    ]
    if not model_ids:
        # Nothing selected: return empty, disabled dropdowns (one per label)
        # instead of failing on model_ids[0].
        model_tree_choices = [[] for _ in model_tree_labels]
    else:
        # TODO: Multiple models?
        model_id = model_ids[0]
        # Results are ordered like model_tree_labels: base models first, then
        # one list per derived-model type.
        model_tree = await asyncio.gather(
            load_base_models(model_id),
            *[
                load_derived_models_by_type(model_id, derived_model_type[1])
                for derived_model_type in constants.DERIVED_MODEL_TYPES
            ],
        )
        # Keep only related models for which results exist.
        model_tree_choices = [
            [related_id for related_id in related_ids if related_id in result_paths_per_model]
            for related_ids in model_tree
        ]
    return [
        gr.Dropdown(choices=choices, label=f"{label} ({len(choices)})", interactive=bool(choices))
        for choices, label in zip(model_tree_choices, model_tree_labels)
    ]
async def load_base_models(model_id) -> list[str]:
    """Return the base models declared in the model card of ``model_id``.

    The value is read from ``card.data`` under the attribute name
    ``constants.BASE_MODEL_TYPE[1]``. It may be a single model id or a list of
    ids, and is normalized to a list.

    Args:
        model_id: Id of the model whose card is loaded.

    Returns:
        The declared base model ids. Returns an empty list when the card could
        not be loaded or declares no base model.
    """
    card = await load_model_card(model_id)
    if not card:
        return []
    # The getattr default keeps metadata without the field from raising.
    # Treating an unset (None/empty) value as "no base models" avoids
    # returning [None].
    base_models = getattr(card.data, constants.BASE_MODEL_TYPE[1], None)
    if not base_models:
        return []
    if not isinstance(base_models, list):
        base_models = [base_models]
    return base_models
async def load_derived_models_by_type(model_id, derived_model_type) -> list[str]:
    """List the ids of Hub models derived from ``model_id`` with one relation type.

    Args:
        model_id: Id of the parent model.
        derived_model_type: Relation type placed in the
            ``base_model:<type>:<id>`` filter passed to ``list_models``.

    Returns:
        The ``"id"`` field of every model returned by ``list_models``, or an
        empty list when it returns a falsy result.
    """
    search_filter = f"base_model:{derived_model_type}:{model_id}"
    found = await list_models(filtering=search_filter)
    # A falsy result (nothing found) yields an empty list.
    return [entry["id"] for entry in found or []]