mlabonne committed on
Commit
47913ac
1 Parent(s): 1b3a43e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -24
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
 
 
3
  from huggingface_hub import ModelCard, HfApi
4
  import requests
5
  import networkx as nx
@@ -20,6 +21,23 @@ TITLE = """
20
  """
21
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def get_model_names_from_yaml(url):
24
  """Get a list of parent model names from the yaml file."""
25
  model_tags = []
@@ -32,7 +50,7 @@ def get_model_names_from_yaml(url):
32
  def get_license_color(model):
33
  """Get the color of the model based on its license."""
34
  try:
35
- card = ModelCard.load(model)
36
  license = card.data.to_dict()['license'].lower()
37
  # Define permissive licenses
38
  permissive_licenses = ['mit', 'bsd', 'apache-2.0', 'openrail'] # Add more as needed
@@ -46,47 +64,45 @@ def get_license_color(model):
46
  return 'lightgray'
47
 
48
 
49
- def get_model_names(model, genealogy, found_models=None):
50
- """Get a list of parent model names from the model id."""
51
- model_tags = []
52
-
53
  if found_models is None:
54
- found_models = []
 
 
 
 
 
 
 
55
 
56
  try:
57
- card = ModelCard.load(model)
58
- card_dict = card.data.to_dict() # Convert the ModelCard object to a dictionary
59
  license = card_dict['license']
60
 
61
- # Check the base_model in metadata
62
  if 'base_model' in card_dict:
63
  model_tags = card_dict['base_model']
64
 
65
- # Check the tags in metadata
66
  if 'tags' in card_dict and not model_tags:
67
  tags = card_dict['tags']
68
  model_tags = [model_name for model_name in tags if '/' in model_name]
69
 
70
- # Check for merge.yml and mergekit_config.yml if no model_tags found in the tags
71
  if not model_tags:
72
  model_tags.extend(get_model_names_from_yaml(f"https://huggingface.co/{model}/blob/main/merge.yml"))
73
  if not model_tags:
74
  model_tags.extend(get_model_names_from_yaml(f"https://huggingface.co/{model}/blob/main/mergekit_config.yml"))
75
 
76
- # Convert to a list if tags is not None or empty, else set to an empty list
77
  if not isinstance(model_tags, list):
78
  model_tags = [model_tags] if model_tags else []
79
 
80
- # Add found model names to the list
81
- found_models.extend(model_tags)
82
 
83
- # Record the genealogy
84
  for model_tag in model_tags:
85
  genealogy[model_tag].append(model)
86
-
87
- # Recursively check for more models
88
- for model_tag in model_tags:
89
- get_model_names(model_tag, genealogy, found_models)
90
 
91
  except Exception as e:
92
  print(f"Could not find model names for {model}: {e}")
@@ -125,6 +141,8 @@ def create_family_tree(start_model):
125
  genealogy = defaultdict(list)
126
  get_model_names(start_model, genealogy) # Assuming this populates the genealogy
127
 
 
 
128
  # Create a directed graph
129
  G = nx.DiGraph()
130
 
@@ -133,11 +151,14 @@ def create_family_tree(start_model):
133
  for child in children:
134
  G.add_edge(parent, child)
135
 
136
- # Get max depth
137
- max_depth = nx.dag_longest_path_length(G) + 1
138
-
139
- # Get max width
140
- max_width = max_width_of_tree(G) + 1
 
 
 
141
 
142
  # Estimate plot size
143
  height = max(8, 1.6 * max_depth)
@@ -149,6 +170,7 @@ def create_family_tree(start_model):
149
 
150
  # Determine node colors based on license
151
  node_colors = [get_license_color(node) for node in G.nodes()]
 
152
 
153
  # Create a label mapping with line breaks
154
  labels = {node: node.replace("/", "\n") for node in G.nodes()}
 
1
  import gradio as gr
2
 
3
+ import sys
4
  from huggingface_hub import ModelCard, HfApi
5
  import requests
6
  import networkx as nx
 
21
  """
22
 
23
 
24
+ # We should first try to cache models
25
class CachedModelCard(ModelCard):
    """ModelCard subclass whose ``load`` memoizes results per model id.

    The cache lives on the class, so it is shared by every caller in the
    process and grows by one entry per distinct model id requested.
    """

    # Maps model id -> loaded card, or None when loading raised.
    _cache = {}

    @classmethod
    def load(cls, model_id: str, **kwargs) -> "ModelCard | None":
        """Return the card for ``model_id``, loading it at most once.

        Args:
            model_id: Identifier forwarded to the parent ``load``.
            **kwargs: Extra arguments forwarded to the parent ``load`` on a
                cache miss only. They are not part of the cache key, so a
                later call with different kwargs gets the first result.

        Returns:
            The loaded card, or ``None`` if loading raised. Failures are
            cached as well, so a failing id is not requested again.
        """
        if model_id in cls._cache:
            print('CACHED:', model_id)
            return cls._cache[model_id]

        print('REQUEST ModelCard:', model_id)
        try:
            card = super().load(model_id, **kwargs)
        except Exception as e:
            # Catch Exception rather than everything so Ctrl-C and
            # SystemExit still propagate; report why the load failed.
            print(f"Could not load ModelCard for {model_id}: {e}")
            card = None
        cls._cache[model_id] = card
        return card
39
+
40
+
41
  def get_model_names_from_yaml(url):
42
  """Get a list of parent model names from the yaml file."""
43
  model_tags = []
 
50
  def get_license_color(model):
51
  """Get the color of the model based on its license."""
52
  try:
53
+ card = CachedModelCard.load(model)
54
  license = card.data.to_dict()['license'].lower()
55
  # Define permissive licenses
56
  permissive_licenses = ['mit', 'bsd', 'apache-2.0', 'openrail'] # Add more as needed
 
64
  return 'lightgray'
65
 
66
 
67
+ def get_model_names(model, genealogy, found_models=None, visited_models=None):
68
+ print('---')
69
+ print(model)
 
70
  if found_models is None:
71
+ found_models = set()
72
+ if visited_models is None:
73
+ visited_models = set()
74
+
75
+ if model in visited_models:
76
+ print("Model already visited...")
77
+ return found_models
78
+ visited_models.add(model)
79
 
80
  try:
81
+ card = CachedModelCard.load(model)
82
+ card_dict = card.data.to_dict()
83
  license = card_dict['license']
84
 
85
+ model_tags = []
86
  if 'base_model' in card_dict:
87
  model_tags = card_dict['base_model']
88
 
 
89
  if 'tags' in card_dict and not model_tags:
90
  tags = card_dict['tags']
91
  model_tags = [model_name for model_name in tags if '/' in model_name]
92
 
 
93
  if not model_tags:
94
  model_tags.extend(get_model_names_from_yaml(f"https://huggingface.co/{model}/blob/main/merge.yml"))
95
  if not model_tags:
96
  model_tags.extend(get_model_names_from_yaml(f"https://huggingface.co/{model}/blob/main/mergekit_config.yml"))
97
 
 
98
  if not isinstance(model_tags, list):
99
  model_tags = [model_tags] if model_tags else []
100
 
101
+ found_models.add(model)
 
102
 
 
103
  for model_tag in model_tags:
104
  genealogy[model_tag].append(model)
105
+ get_model_names(model_tag, genealogy, found_models, visited_models)
 
 
 
106
 
107
  except Exception as e:
108
  print(f"Could not find model names for {model}: {e}")
 
141
  genealogy = defaultdict(list)
142
  get_model_names(start_model, genealogy) # Assuming this populates the genealogy
143
 
144
+ print("Number of models:", len(CachedModelCard._cache))
145
+
146
  # Create a directed graph
147
  G = nx.DiGraph()
148
 
 
151
  for child in children:
152
  G.add_edge(parent, child)
153
 
154
+ try:
155
+ # Get max depth and width
156
+ max_depth = nx.dag_longest_path_length(G) + 1
157
+ max_width = max_width_of_tree(G) + 1
158
+ except:
159
+ # Get max depth and width
160
+ max_depth = 21
161
+ max_width = 9
162
 
163
  # Estimate plot size
164
  height = max(8, 1.6 * max_depth)
 
170
 
171
  # Determine node colors based on license
172
  node_colors = [get_license_color(node) for node in G.nodes()]
173
+ clear_output()
174
 
175
  # Create a label mapping with line breaks
176
  labels = {node: node.replace("/", "\n") for node in G.nodes()}