CultriX committed on
Commit
490def8
1 Parent(s): 66f98b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -97
app.py CHANGED
@@ -1,5 +1,5 @@
 
1
  import gradio as gr
2
-
3
  import sys
4
  from huggingface_hub import ModelCard, HfApi
5
  import requests
@@ -11,17 +11,10 @@ from networkx.drawing.nx_pydot import graphviz_layout
11
  from io import BytesIO
12
  from PIL import Image
13
 
 
 
14
 
15
- TITLE = """
16
- <div align="center">
17
- <p style="font-size: 36px;">🌳 Model Family Tree</p>
18
- </div><br/>
19
- <p>Automatically calculate the <strong>family tree of a given model</strong>. It also displays the type of license each model uses (permissive, noncommercial, or unknown).</p>
20
- <p>You can also run the code in this <a href="https://colab.research.google.com/drive/1s2eQlolcI1VGgDhqWIANfkfKvcKrMyNr?usp=sharing">Colab notebook</a>. Special thanks to <a href="https://huggingface.co/leonardlin">leonardlin</a> for his caching implementation. See also mrfakename's version in <a href="https://huggingface.co/spaces/mrfakename/merge-model-tree">this space</a>.</p>
21
- """
22
-
23
-
24
- # We should first try to cache models
25
  class CachedModelCard(ModelCard):
26
  _cache = {}
27
 
@@ -29,51 +22,40 @@ class CachedModelCard(ModelCard):
29
  def load(cls, model_id: str, **kwargs) -> "ModelCard":
30
  if model_id not in cls._cache:
31
  try:
32
- print('REQUEST ModelCard:', model_id)
33
  cls._cache[model_id] = super().load(model_id, **kwargs)
34
  except:
35
  cls._cache[model_id] = None
36
- else:
37
- print('CACHED:', model_id)
38
  return cls._cache[model_id]
39
 
40
-
41
  def get_model_names_from_yaml(url):
42
- """Get a list of parent model names from the yaml file."""
43
  model_tags = []
44
  response = requests.get(url)
45
  if response.status_code == 200:
46
  model_tags.extend([item for item in response.content if '/' in str(item)])
47
  return model_tags
48
 
49
-
50
  def get_license_color(model):
51
- """Get the color of the model based on its license."""
52
  try:
53
  card = CachedModelCard.load(model)
54
  license = card.data.to_dict()['license'].lower()
55
- # Define permissive licenses
56
- permissive_licenses = ['mit', 'bsd', 'apache-2.0', 'openrail'] # Add more as needed
57
- # Check license type
58
  if any(perm_license in license for perm_license in permissive_licenses):
59
- return 'lightgreen' # Permissive licenses
60
  else:
61
- return 'lightcoral' # Noncommercial or other licenses
62
  except Exception as e:
63
- print(f"Error retrieving license for {model}: {e}")
64
  return 'lightgray'
65
 
66
-
67
  def get_model_names(model, genealogy, found_models=None, visited_models=None):
68
- print('---')
69
- print(model)
70
  if found_models is None:
71
  found_models = set()
72
  if visited_models is None:
73
  visited_models = set()
74
 
75
  if model in visited_models:
76
- print("Model already visited...")
77
  return found_models
78
  visited_models.add(model)
79
 
@@ -105,79 +87,37 @@ def get_model_names(model, genealogy, found_models=None, visited_models=None):
105
  get_model_names(model_tag, genealogy, found_models, visited_models)
106
 
107
  except Exception as e:
108
- print(f"Could not find model names for {model}: {e}")
109
 
110
  return found_models
111
 
112
-
113
- def find_root_nodes(G):
114
- """ Find all nodes in the graph with no predecessors """
115
- return [n for n, d in G.in_degree() if d == 0]
116
-
117
-
118
- def max_width_of_tree(G):
119
- """ Calculate the maximum width of the tree """
120
- max_width = 0
121
- for root in find_root_nodes(G):
122
- width_at_depth = calculate_width_at_depth(G, root)
123
- local_max_width = max(width_at_depth.values())
124
- max_width = max(max_width, local_max_width)
125
- return max_width
126
-
127
-
128
- def calculate_width_at_depth(G, root):
129
- """ Calculate width at each depth starting from a given root """
130
- depth_count = defaultdict(int)
131
- queue = [(root, 0)]
132
- while queue:
133
- node, depth = queue.pop(0)
134
- depth_count[depth] += 1
135
- for child in G.successors(node):
136
- queue.append((child, depth + 1))
137
- return depth_count
138
-
139
-
140
  def create_family_tree(start_model):
141
  genealogy = defaultdict(list)
142
- get_model_names(start_model, genealogy) # Assuming this populates the genealogy
143
 
144
- print("Number of models:", len(CachedModelCard._cache))
145
-
146
- # Create a directed graph
147
  G = nx.DiGraph()
148
 
149
- # Add nodes and edges to the graph
150
  for parent, children in genealogy.items():
151
  for child in children:
152
  G.add_edge(parent, child)
153
 
154
- try:
155
- # Get max depth and width
156
- max_depth = nx.dag_longest_path_length(G) + 1
157
- max_width = max_width_of_tree(G) + 1
158
- except:
159
- # Get max depth and width
160
- max_depth = 21
161
- max_width = 9
162
-
163
- # Estimate plot size
164
  height = max(8, 1.6 * max_depth)
165
  width = max(8, 6 * max_width)
166
 
167
- # Set Graphviz layout attributes for a bottom-up tree
168
  plt.figure(figsize=(width, height))
169
  pos = graphviz_layout(G, prog="dot")
170
 
171
- # Determine node colors based on license
172
  node_colors = [get_license_color(node) for node in G.nodes()]
 
173
 
174
- # Create a label mapping with line breaks
175
  labels = {node: node.replace("/", "\n") for node in G.nodes()}
176
 
177
- # Draw the graph
178
  nx.draw(G, pos, labels=labels, with_labels=True, node_color=node_colors, font_size=12, node_size=8_000, edge_color='black')
179
 
180
- # Create a legend for the colors
181
  legend_elements = [
182
  Patch(facecolor='lightgreen', label='Permissive'),
183
  Patch(facecolor='lightcoral', label='Noncommercial'),
@@ -186,23 +126,6 @@ def create_family_tree(start_model):
186
  plt.legend(handles=legend_elements, loc='upper left')
187
 
188
  plt.title(f"{start_model}'s Family Tree", fontsize=20)
189
-
190
- # Capture the plot as an image in memory
191
- img_buffer = BytesIO()
192
- plt.savefig(img_buffer, format='png', bbox_inches='tight')
193
- plt.close()
194
- img_buffer.seek(0)
195
-
196
- # Open the image using PIL
197
- img = Image.open(img_buffer)
198
-
199
- return img
200
-
201
- with gr.Blocks() as demo:
202
- gr.Markdown(TITLE)
203
- model_id = gr.Textbox(label="Model ID", value="mlabonne/NeuralBeagle14-7B")
204
- btn = gr.Button("Create tree")
205
- out = gr.Image()
206
- btn.click(fn=create_family_tree, inputs=model_id, outputs=out)
207
-
208
- demo.queue(api_open=False).launch(show_api=False)
 
1
+ # Import necessary libraries
2
  import gradio as gr
 
3
  import sys
4
  from huggingface_hub import ModelCard, HfApi
5
  import requests
 
11
  from io import BytesIO
12
  from PIL import Image
13
 
14
+ # Define the model ID
15
+ MODEL_ID = "mlabonne/NeuralBeagle14-7B"
16
 
17
+ # Define a class to cache model cards
 
 
 
 
 
 
 
 
 
18
  class CachedModelCard(ModelCard):
19
  _cache = {}
20
 
 
22
  def load(cls, model_id: str, **kwargs) -> "ModelCard":
23
  if model_id not in cls._cache:
24
  try:
 
25
  cls._cache[model_id] = super().load(model_id, **kwargs)
26
  except:
27
  cls._cache[model_id] = None
 
 
28
  return cls._cache[model_id]
29
 
30
+ # Function to get model names from a YAML file
31
  def get_model_names_from_yaml(url):
 
32
  model_tags = []
33
  response = requests.get(url)
34
  if response.status_code == 200:
35
  model_tags.extend([item for item in response.content if '/' in str(item)])
36
  return model_tags
37
 
38
+ # Function to get the color of the model based on its license
39
  def get_license_color(model):
 
40
  try:
41
  card = CachedModelCard.load(model)
42
  license = card.data.to_dict()['license'].lower()
43
+ permissive_licenses = ['mit', 'bsd', 'apache-2.0', 'openrail']
 
 
44
  if any(perm_license in license for perm_license in permissive_licenses):
45
+ return 'lightgreen'
46
  else:
47
+ return 'lightcoral'
48
  except Exception as e:
 
49
  return 'lightgray'
50
 
51
+ # Function to find model names in the family tree
52
  def get_model_names(model, genealogy, found_models=None, visited_models=None):
 
 
53
  if found_models is None:
54
  found_models = set()
55
  if visited_models is None:
56
  visited_models = set()
57
 
58
  if model in visited_models:
 
59
  return found_models
60
  visited_models.add(model)
61
 
 
87
  get_model_names(model_tag, genealogy, found_models, visited_models)
88
 
89
  except Exception as e:
90
+ pass
91
 
92
  return found_models
93
 
94
+ # Function to create the family tree
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  def create_family_tree(start_model):
96
  genealogy = defaultdict(list)
97
+ get_model_names(start_model, genealogy)
98
 
 
 
 
99
  G = nx.DiGraph()
100
 
 
101
  for parent, children in genealogy.items():
102
  for child in children:
103
  G.add_edge(parent, child)
104
 
105
+ max_depth = nx.dag_longest_path_length(G) + 1
106
+ max_width = max_width_of_tree(G) + 1
107
+
 
 
 
 
 
 
 
108
  height = max(8, 1.6 * max_depth)
109
  width = max(8, 6 * max_width)
110
 
 
111
  plt.figure(figsize=(width, height))
112
  pos = graphviz_layout(G, prog="dot")
113
 
 
114
  node_colors = [get_license_color(node) for node in G.nodes()]
115
+ clear_output()
116
 
 
117
  labels = {node: node.replace("/", "\n") for node in G.nodes()}
118
 
 
119
  nx.draw(G, pos, labels=labels, with_labels=True, node_color=node_colors, font_size=12, node_size=8_000, edge_color='black')
120
 
 
121
  legend_elements = [
122
  Patch(facecolor='lightgreen', label='Permissive'),
123
  Patch(facecolor='lightcoral', label='Noncommercial'),
 
126
  plt.legend(handles=legend_elements, loc='upper left')
127
 
128
  plt.title(f"{start_model}'s Family Tree", fontsize=20)
129
+ plt.show()
130
+
131
+ create_family_tree(MODEL_ID)