File size: 1,817 Bytes
460d762
 
 
 
d52179b
460d762
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d52179b
460d762
 
 
 
 
 
d52179b
460d762
 
 
 
 
d52179b
460d762
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import re
from typing import List

from src.utils_display import AutoEvalColumn
from src.auto_leaderboard.model_metadata_type import get_model_type

from huggingface_hub import HfApi
import huggingface_hub
# Shared Hub client, created once at import time and used by get_model_infos_from_hub.
api = HfApi()


def get_model_infos_from_hub(leaderboard_data: List[dict]):
    """Fill in license, likes and size for every leaderboard row, in place.

    Each row must carry a ``"model_name_for_query"`` key holding the repo id
    to look up. The row is updated with three columns named after
    ``AutoEvalColumn.license``, ``AutoEvalColumn.likes`` and
    ``AutoEvalColumn.params``.

    Args:
        leaderboard_data: Rows to enrich. Mutated in place; nothing is returned.

    Rows whose repository cannot be found on the Hub get ``None`` in all
    three columns instead of raising.
    """
    for model_data in leaderboard_data:
        model_name = model_data["model_name_for_query"]
        try:
            model_info = api.model_info(model_name)
        # Use the public re-export rather than the private ``utils._errors``
        # module, whose location is an implementation detail of the library.
        except huggingface_hub.utils.RepositoryNotFoundError:
            model_data[AutoEvalColumn.license.name] = None
            model_data[AutoEvalColumn.likes.name] = None
            model_data[AutoEvalColumn.params.name] = None
            continue

        model_data[AutoEvalColumn.license.name] = get_model_license(model_info)
        model_data[AutoEvalColumn.likes.name] = get_model_likes(model_info)
        model_data[AutoEvalColumn.params.name] = get_model_size(model_name, model_info)


def get_model_license(model_info):
    """Return the ``"license"`` entry of ``model_info.cardData``.

    Any failure while reading it (missing attribute, missing key, card data
    that is not subscriptable, ...) is treated as "no license" and yields
    ``None``.
    """
    try:
        card_data = model_info.cardData
        license_name = card_data["license"]
    except Exception:
        # Best-effort lookup: an unreadable card simply means no license.
        return None
    return license_name

def get_model_likes(model_info):
    """Return the ``likes`` attribute of ``model_info`` unchanged."""
    likes = model_info.likes
    return likes

# Parameter-count token in a lowercased model name: "7b", "1.3b", "350m".
# The optional decimal part keeps "1.3b" from being read as "3b"; the trailing
# lookahead rejects digits that run into a word (the "2m" in "gpt2-medium").
# Group 1 is the unit letter.
size_pattern = re.compile(r"\d+(?:\.\d+)?(b|m)(?![a-z])")


def get_model_size(model_name, model_info):
    """Return the model size in billions of parameters, or ``None`` if unknown.

    The safetensors parameter total on ``model_info`` is preferred. It is
    read as ``safetensors["total"]`` when the metadata is a mapping and as
    ``safetensors.total`` otherwise. When no total is available, the size is
    parsed from a ``<number>b`` / ``<number>m`` token in ``model_name``.

    Args:
        model_name: Repo id or model name; matched case-insensitively.
        model_info: Object that may expose a ``safetensors`` attribute.

    Returns:
        The size rounded to three decimals, or ``None`` when neither source
        provides one.
    """
    safetensors = getattr(model_info, "safetensors", None)
    try:
        total = safetensors["total"]
    except (TypeError, KeyError, IndexError):
        # Metadata is absent, lacks "total", or is not subscriptable; fall
        # back to attribute access before giving up on this source.
        total = getattr(safetensors, "total", None)
    if total is not None:
        return round(total / 1e9, 3)

    # No weight metadata: fall back to the size advertised in the name.
    if not isinstance(model_name, str):
        return None
    size_match = size_pattern.search(model_name.lower())
    if size_match is None:
        return None
    token = size_match.group(0)
    value = float(token[:-1])
    if token[-1] == "m":
        # Millions -> billions.
        value /= 1e3
    return round(value, 3)


def apply_metadata(leaderboard_data: List[dict]):
    """Run every metadata step over ``leaderboard_data``.

    ``get_model_type`` is applied first, then ``get_model_infos_from_hub``;
    both receive the same list object. Nothing is returned.
    """
    enrichment_steps = (get_model_type, get_model_infos_from_hub)
    for enrich in enrichment_steps:
        enrich(leaderboard_data)