File size: 5,644 Bytes
99ced83
15bf463
b1dd47e
dbafbbf
b1dd47e
15bf463
b1dd47e
 
af40811
 
 
 
 
 
 
 
45e9ce6
 
 
4267129
 
15bf463
 
dbafbbf
99ced83
15bf463
b1dd47e
15bf463
 
 
99ced83
 
 
 
 
4267129
 
99ced83
dbafbbf
99ced83
dbafbbf
 
4267129
dbafbbf
 
4267129
99ced83
dbafbbf
 
 
 
 
 
99ced83
dbafbbf
 
 
 
99ced83
dbafbbf
 
 
 
 
 
 
99ced83
dbafbbf
 
99ced83
 
b1dd47e
 
 
 
 
 
 
 
da1451c
 
 
 
 
 
99ced83
dbafbbf
 
45e9ce6
 
 
dbafbbf
 
 
 
 
 
da1451c
dbafbbf
 
da1451c
dbafbbf
 
 
 
da1451c
 
 
 
 
99ced83
dbafbbf
 
 
da1451c
 
 
 
 
 
dbafbbf
b1dd47e
 
da1451c
 
 
 
b1dd47e
45e9ce6
b1dd47e
 
 
da1451c
 
 
 
 
 
 
 
4267129
 
da1451c
 
 
 
 
dbafbbf
 
 
45e9ce6
dbafbbf
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
from typing import Generator, Set, Union, List, Optional

import requests
from bs4 import BeautifulSoup, Tag, NavigableString, PageElement
from concurrent.futures import ThreadPoolExecutor, as_completed

# Listing of text-generation models with PyTorch weights; the page index is appended as "&p=<index>".
SUPPORTED_MODEL_NAME_PAGES_FORMAT = "https://huggingface.co/models?pipeline_tag=text-generation&library=pytorch"
# Size of the thread pool that downloads listing pages concurrently.
MAX_WORKERS = 10

# Individual models that are always excluded, whatever their download/like counts.
BLACKLISTED_MODEL_NAMES = {
    "ykilcher/gpt-4chan",
    "bigscience/mt0-xxl",
    "bigscience/mt0-xl",
    "bigscience/mt0-large",
    "bigscience/mt0-base",
    "bigscience/mt0-small",
}
# Organizations (the part of a model name before the first "/") whose models are all excluded.
BLACKLISTED_ORGANIZATIONS = {"huggingtweets"}

# Default popularity thresholds used by get_supported_model_names.
DEFAULT_MIN_NUMBER_OF_DOWNLOADS = 100
DEFAULT_MIN_NUMBER_OF_LIKES = 20


def get_model_name(model_card: Tag) -> str:
    """Return the text of the model card's title ``h4``.

    The ``h4`` is located by its exact class string. ``card_filter`` later splits
    this text on "/" to get the organization, so it is presumably "org/name" —
    TODO confirm for models without an organization prefix.
    If no such ``h4`` exists, ``find`` yields ``None`` and this raises ``AttributeError``.
    """
    title = model_card.find(
        "h4",
        class_="text-md truncate font-mono text-black dark:group-hover:text-yellow-500 group-hover:text-indigo-600",
    )
    return title.text


def is_a_number(element: Union[PageElement, Tag]) -> bool:
    """Return True if *element* is a text node holding a count such as "345", "1,234" or "1.2k".

    Tags are never numbers. For text, leading/trailing whitespace is stripped, the
    text is lowercased and "," thousands separators are removed. What remains must be
    decimal digits, optionally followed by one "." and more digits, optionally ending
    in a single "k", "m" or "b" magnitude suffix.

    Unlike a plain ``float()`` check, this rejects "nan", "inf", exponents ("1e5"),
    repeated dots ("1.2.3") and suffixes in the middle ("1k5").
    """
    if isinstance(element, Tag):
        return False
    text = element.text.strip().lower().replace(",", "")
    # Only a single trailing magnitude suffix is allowed.
    if text.endswith(("k", "m", "b")):
        text = text[:-1]
    whole, _, fraction = text.partition(".")
    return whole.isdecimal() and (not fraction or fraction.isdecimal())


def get_numeric_contents(model_card: Tag) -> List[PageElement]:
    """Return the numeric child nodes of the model card's stats ``div``, if it exists.

    The stats ``div`` is found (searching all descendants) by its exact class string.
    Its direct children that satisfy ``is_a_number`` are returned in document order.
    ``card_filter`` treats them as [downloads, likes, ...].
    If the card has no such ``div``, an empty list is returned rather than raising.
    """
    stats_div = model_card.find(
        "div",
        class_="mr-1 flex items-center overflow-hidden whitespace-nowrap text-sm leading-tight text-gray-400",
        recursive=True,
    )
    if not isinstance(stats_div, Tag):
        # Missing stats row: report "no numbers" so the caller can reject the card
        # instead of crashing on None.contents and aborting the whole scrape.
        return []
    return [child for child in stats_div.contents if is_a_number(child)]


def convert_to_int(element: PageElement) -> int:
    """Convert a displayed count such as "345", "1,234", "1.2k", "3M" or "2B" to an int.

    The text is stripped and lowercased, and "," thousands separators are removed.
    A trailing "k"/"m"/"b" multiplies by 1e3/1e6/1e9.
    Values scaled through floats are rounded to the nearest integer rather than
    truncated, so float error cannot turn "1.005k" into 1004.

    :raises ValueError: if the text is not a number in one of these forms.
    """
    multipliers = {"k": 1_000, "m": 1_000_000, "b": 1_000_000_000}
    text = element.text.strip().lower().replace(",", "")
    multiplier = multipliers.get(text[-1:])
    if multiplier is None:
        try:
            # Plain integers are parsed exactly.
            return int(text)
        except ValueError:
            # Bare decimals such as "1.5" (float() re-raises ValueError if invalid).
            return round(float(text))
    return round(float(text[:-1]) * multiplier)


def get_page(page_index: int, timeout: float = 30.0) -> Optional[BeautifulSoup]:
    """Download and parse one page of the model listing.

    :param page_index: value sent as the ``p`` query parameter of the listing URL.
    :param timeout: seconds ``requests`` may wait on connect/read. Without a timeout,
        a stalled connection would block this worker thread indefinitely.
    :return: the page parsed with ``html.parser``, or ``None`` if the request failed
        (network error or timeout) or the response status was not 200.
    """
    page_url = f"{SUPPORTED_MODEL_NAME_PAGES_FORMAT}&p={page_index}"
    try:
        response = requests.get(page_url, timeout=timeout)
    except requests.RequestException:
        # Treat a failed request like a missing page, so one bad request does not
        # abort the whole concurrent scrape.
        return None
    if response.status_code != 200:
        return None
    return BeautifulSoup(response.content, "html.parser")


def card_filter(
        model_card: Tag,
        model_name: str,
        min_number_of_downloads: int,
        min_number_of_likes: int,
) -> bool:
    """Decide whether a model card should be kept.

    A card is rejected when either of these holds:
    - the model, or its organization (text before the first "/"), is blacklisted;
    - it has fewer than two numeric entries;
    - the first numeric entry (downloads) or the second (likes) is below its threshold.
    Likes are only converted once the download threshold has been met.
    """
    is_blacklisted = (
        model_name in BLACKLISTED_MODEL_NAMES
        or model_name.partition("/")[0] in BLACKLISTED_ORGANIZATIONS
    )
    if is_blacklisted:
        return False

    stats = get_numeric_contents(model_card)
    # Fewer than two numbers means the card does not show both a download and a
    # like count, so it cannot pass the popularity checks.
    if len(stats) < 2:
        return False

    downloads_node, likes_node = stats[0], stats[1]
    return (
        convert_to_int(downloads_node) >= min_number_of_downloads
        and convert_to_int(likes_node) >= min_number_of_likes
    )


def get_model_names(
        soup: BeautifulSoup,
        min_number_of_downloads: int,
        min_number_of_likes: int,
) -> Generator[str, None, None]:
    """Yield, in document order, the names of the cards on *soup* that pass ``card_filter``.

    Cards are the ``article`` elements (at any depth) whose class string is
    "overview-card-wrapper group".
    """
    cards: List[Tag] = soup.find_all("article", class_="overview-card-wrapper group", recursive=True)
    for card in cards:
        name = get_model_name(card)
        keep = card_filter(
            model_card=card,
            model_name=name,
            min_number_of_downloads=min_number_of_downloads,
            min_number_of_likes=min_number_of_likes,
        )
        if not keep:
            continue
        yield name


def generate_supported_model_names(
        min_number_of_downloads: int,
        min_number_of_likes: int,
        number_of_pages: int = 300,
) -> Generator[str, None, None]:
    """Yield the names of models that pass the filters, across the listing pages.

    Pages ``0 .. number_of_pages - 1`` are fetched with ``get_page`` on a pool of
    ``MAX_WORKERS`` threads. Pages for which ``get_page`` returns ``None`` are skipped.
    Names are yielded as pages finish downloading, so the output order is not
    deterministic.

    :param min_number_of_downloads: minimum download count for a model to be kept.
    :param min_number_of_likes: minimum like count for a model to be kept.
    :param number_of_pages: how many listing pages to request (default 300, the
        previous hard-coded value).
    """
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = [executor.submit(get_page, index) for index in range(number_of_pages)]
        try:
            for future in as_completed(futures):
                soup = future.result()
                if soup is not None:
                    yield from get_model_names(
                        soup=soup,
                        min_number_of_downloads=min_number_of_downloads,
                        min_number_of_likes=min_number_of_likes,
                    )
        finally:
            # If the consumer stops early or an error escapes, cancel the pages not yet
            # started, so the executor's shutdown does not wait to download them.
            for future in futures:
                future.cancel()


def get_supported_model_names(
        min_number_of_downloads: int = DEFAULT_MIN_NUMBER_OF_DOWNLOADS,
        min_number_of_likes: int = DEFAULT_MIN_NUMBER_OF_LIKES,
) -> Set[str]:
    """Return the de-duplicated set of model names that pass the download/like thresholds.

    Runs ``generate_supported_model_names`` to completion with the given thresholds.
    """
    scraped_names = generate_supported_model_names(
        min_number_of_downloads=min_number_of_downloads,
        min_number_of_likes=min_number_of_likes,
    )
    unique_names: Set[str] = set()
    unique_names.update(scraped_names)
    return unique_names


if __name__ == "__main__":
    # Script entry point: keep every model with at least one download and one like.
    supported_model_names = get_supported_model_names(
        min_number_of_downloads=1,
        min_number_of_likes=1,
    )
    print(f"Number of supported model names: {len(supported_model_names)}")
    print(f"Supported model names: {supported_model_names}")