CultriX committed on
Commit
5f8cec7
1 Parent(s): 490def8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -11
app.py CHANGED
@@ -11,51 +11,72 @@ from networkx.drawing.nx_pydot import graphviz_layout
11
  from io import BytesIO
12
  from PIL import Image
13
 
 
 
 
 
 
 
 
 
14
  # Define the model ID
15
  MODEL_ID = "mlabonne/NeuralBeagle14-7B"
16
 
17
  # Define a class to cache model cards
18
  class CachedModelCard(ModelCard):
19
  _cache = {}
20
-
21
  @classmethod
22
  def load(cls, model_id: str, **kwargs) -> "ModelCard":
23
  if model_id not in cls._cache:
24
  try:
 
25
  cls._cache[model_id] = super().load(model_id, **kwargs)
26
  except:
27
  cls._cache[model_id] = None
 
 
28
  return cls._cache[model_id]
29
 
30
  # Function to get model names from a YAML file
31
  def get_model_names_from_yaml(url):
 
32
  model_tags = []
33
  response = requests.get(url)
34
  if response.status_code == 200:
35
  model_tags.extend([item for item in response.content if '/' in str(item)])
36
  return model_tags
37
 
 
38
  # Function to get the color of the model based on its license
39
  def get_license_color(model):
 
40
  try:
41
  card = CachedModelCard.load(model)
42
  license = card.data.to_dict()['license'].lower()
43
- permissive_licenses = ['mit', 'bsd', 'apache-2.0', 'openrail']
 
 
44
  if any(perm_license in license for perm_license in permissive_licenses):
45
- return 'lightgreen'
46
  else:
47
- return 'lightcoral'
48
  except Exception as e:
 
49
  return 'lightgray'
50
 
 
51
  # Function to find model names in the family tree
52
  def get_model_names(model, genealogy, found_models=None, visited_models=None):
 
 
53
  if found_models is None:
54
  found_models = set()
55
  if visited_models is None:
56
  visited_models = set()
57
 
58
  if model in visited_models:
 
59
  return found_models
60
  visited_models.add(model)
61
 
@@ -87,37 +108,79 @@ def get_model_names(model, genealogy, found_models=None, visited_models=None):
87
  get_model_names(model_tag, genealogy, found_models, visited_models)
88
 
89
  except Exception as e:
90
- pass
91
 
92
  return found_models
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  # Function to create the family tree
95
  def create_family_tree(start_model):
96
  genealogy = defaultdict(list)
97
- get_model_names(start_model, genealogy)
 
 
98
 
 
99
  G = nx.DiGraph()
100
 
 
101
  for parent, children in genealogy.items():
102
  for child in children:
103
  G.add_edge(parent, child)
104
 
105
- max_depth = nx.dag_longest_path_length(G) + 1
106
- max_width = max_width_of_tree(G) + 1
107
-
 
 
 
 
 
 
 
108
  height = max(8, 1.6 * max_depth)
109
  width = max(8, 6 * max_width)
110
 
 
111
  plt.figure(figsize=(width, height))
112
  pos = graphviz_layout(G, prog="dot")
113
 
 
114
  node_colors = [get_license_color(node) for node in G.nodes()]
115
- clear_output()
116
 
 
117
  labels = {node: node.replace("/", "\n") for node in G.nodes()}
118
 
 
119
  nx.draw(G, pos, labels=labels, with_labels=True, node_color=node_colors, font_size=12, node_size=8_000, edge_color='black')
120
 
 
121
  legend_elements = [
122
  Patch(facecolor='lightgreen', label='Permissive'),
123
  Patch(facecolor='lightcoral', label='Noncommercial'),
@@ -126,6 +189,28 @@ def create_family_tree(start_model):
126
  plt.legend(handles=legend_elements, loc='upper left')
127
 
128
  plt.title(f"{start_model}'s Family Tree", fontsize=20)
129
- plt.show()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  create_family_tree(MODEL_ID)
 
 
 
11
  from io import BytesIO
12
  from PIL import Image
13
 
# HTML snippet shown as the page header; it is handed to gr.Markdown when the
# UI is built.
TITLE = """
<div align="center">
<p style="font-size: 36px;">🌳 Model Family Tree</p>
</div><br/>
<p>Automatically calculate the <strong>family tree of a given model</strong>. It also displays the type of license each model uses (permissive, noncommercial, or unknown).</p>
<p>You can also run the code in this <a href="https://colab.research.google.com/drive/1s2eQlolcI1VGgDhqWIANfkfKvcKrMyNr?usp=sharing">Colab notebook</a>. Special thanks to <a href="https://huggingface.co/leonardlin">leonardlin</a> for his caching implementation. See also mrfakename's version in <a href="https://huggingface.co/spaces/mrfakename/merge-model-tree">this space</a>.</p>
"""

# Define the default model ID
MODEL_ID = "mlabonne/NeuralBeagle14-7B"
24
 
# Define a class to cache model cards
class CachedModelCard(ModelCard):
    """ModelCard subclass whose ``load`` remembers its result per model ID."""

    # Maps model_id -> loaded card, or None when the load attempt failed.
    _cache = {}

    @classmethod
    def load(cls, model_id: str, **kwargs) -> "ModelCard":
        """Return the card for ``model_id``, loading it at most once.

        The first call for a given ``model_id`` delegates to the parent
        ``load``; if that raises, ``None`` is stored so the failing request
        is not repeated. Later calls return the stored value (which may be
        ``None``), ignoring ``kwargs``.
        """
        if model_id not in cls._cache:
            try:
                print('REQUEST ModelCard:', model_id)
                cls._cache[model_id] = super().load(model_id, **kwargs)
            except Exception:
                # Best-effort: remember the failure instead of propagating it.
                # ``Exception`` (not a bare except) so Ctrl+C / SystemExit
                # still interrupt the program.
                cls._cache[model_id] = None
        else:
            print('CACHED:', model_id)
        return cls._cache[model_id]
40
 
# Function to get model names from a YAML file
def get_model_names_from_yaml(url):
    """Get a list of parent model names from the yaml file.

    Downloads ``url`` and returns every whitespace-separated token of the
    response body that contains a ``/`` (surrounding quotes, commas and
    brackets are stripped). Returns an empty list when the response status
    is not 200.

    NOTE(review): this is a textual filter, not a YAML parse; confirm it
    matches the layout of the files this is pointed at.
    """
    model_tags = []
    response = requests.get(url)
    if response.status_code == 200:
        # ``response.content`` is bytes: iterating it yields ints, so the
        # previous ``'/' in str(item)`` test could never match. Work on the
        # decoded text instead.
        for token in response.text.split():
            candidate = token.strip("\"',[]")
            if '/' in candidate:
                model_tags.append(candidate)
    return model_tags
49
 
50
+
# Function to get the color of the model based on its license
def get_license_color(model):
    """Map a model's license to a node color.

    Returns 'lightgreen' when the lower-cased license string contains one of
    the permissive markers, 'lightcoral' otherwise, and 'lightgray' (after
    printing the error) when the license cannot be read for any reason.
    """
    # Substrings that mark a license as permissive; add more as needed.
    permissive_markers = ('mit', 'bsd', 'apache-2.0', 'openrail')
    try:
        card = CachedModelCard.load(model)
        license_name = card.data.to_dict()['license'].lower()
        is_permissive = any(marker in license_name for marker in permissive_markers)
    except Exception as e:
        print(f"Error retrieving license for {model}: {e}")
        return 'lightgray'
    # Permissive licenses are green; noncommercial or other licenses are red.
    return 'lightgreen' if is_permissive else 'lightcoral'
67
 
68
+
69
  # Function to find model names in the family tree
70
  def get_model_names(model, genealogy, found_models=None, visited_models=None):
71
+ print('---')
72
+ print(model)
73
  if found_models is None:
74
  found_models = set()
75
  if visited_models is None:
76
  visited_models = set()
77
 
78
  if model in visited_models:
79
+ print("Model already visited...")
80
  return found_models
81
  visited_models.add(model)
82
 
 
108
  get_model_names(model_tag, genealogy, found_models, visited_models)
109
 
110
  except Exception as e:
111
+ print(f"Could not find model names for {model}: {e}")
112
 
113
  return found_models
114
 
def find_root_nodes(G):
    """Return the nodes of ``G`` that have no predecessors (in-degree 0)."""
    roots = []
    for node, degree in G.in_degree():
        if degree == 0:
            roots.append(node)
    return roots
118
+
119
+
def max_width_of_tree(G):
    """Return the largest per-depth node count over all roots of ``G``.

    Each root found by ``find_root_nodes`` is measured with
    ``calculate_width_at_depth``; the result is 0 when ``G`` has no roots.
    """
    widest_per_root = (
        max(calculate_width_at_depth(G, root).values())
        for root in find_root_nodes(G)
    )
    return max(widest_per_root, default=0)
128
+
129
+
def calculate_width_at_depth(G, root):
    """Count the nodes at each depth of a breadth-first walk from ``root``.

    Returns a ``defaultdict(int)`` mapping depth (0 for ``root``) to the
    number of distinct nodes first reached at that depth. Each node is
    counted once, at its shallowest depth, so shared descendants are not
    counted once per path and a cycle cannot make the walk run forever.
    """
    from collections import deque

    depth_count = defaultdict(int)
    seen = {root}
    queue = deque([(root, 0)])  # deque: O(1) pops from the left
    while queue:
        node, depth = queue.popleft()
        depth_count[depth] += 1
        for child in G.successors(node):
            if child not in seen:
                seen.add(child)
                queue.append((child, depth + 1))
    return depth_count
140
+
141
+
142
  # Function to create the family tree
143
  def create_family_tree(start_model):
144
  genealogy = defaultdict(list)
145
+ get_model_names(start_model, genealogy) # Assuming this populates the genealogy
146
+
147
+ print("Number of models:", len(CachedModelCard._cache))
148
 
149
+ # Create a directed graph
150
  G = nx.DiGraph()
151
 
152
+ # Add nodes and edges to the graph
153
  for parent, children in genealogy.items():
154
  for child in children:
155
  G.add_edge(parent, child)
156
 
157
+ try:
158
+ # Get max depth and width
159
+ max_depth = nx.dag_longest_path_length(G) + 1
160
+ max_width = max_width_of_tree(G) + 1
161
+ except:
162
+ # Get max depth and width
163
+ max_depth = 21
164
+ max_width = 9
165
+
166
+ # Estimate plot size
167
  height = max(8, 1.6 * max_depth)
168
  width = max(8, 6 * max_width)
169
 
170
+ # Set Graphviz layout attributes for a bottom-up tree
171
  plt.figure(figsize=(width, height))
172
  pos = graphviz_layout(G, prog="dot")
173
 
174
+ # Determine node colors based on license
175
  node_colors = [get_license_color(node) for node in G.nodes()]
 
176
 
177
+ # Create a label mapping with line breaks
178
  labels = {node: node.replace("/", "\n") for node in G.nodes()}
179
 
180
+ # Draw the graph
181
  nx.draw(G, pos, labels=labels, with_labels=True, node_color=node_colors, font_size=12, node_size=8_000, edge_color='black')
182
 
183
+ # Create a legend for the colors
184
  legend_elements = [
185
  Patch(facecolor='lightgreen', label='Permissive'),
186
  Patch(facecolor='lightcoral', label='Noncommercial'),
 
189
  plt.legend(handles=legend_elements, loc='upper left')
190
 
191
  plt.title(f"{start_model}'s Family Tree", fontsize=20)
192
+
193
+ # Capture the plot as an image in memory
194
+ img_buffer = BytesIO()
195
+ plt.savefig(img_buffer, format='png', bbox_inches='tight')
196
+ plt.close()
197
+ img_buffer.seek(0)
198
+
199
+ # Open the image using PIL
200
+ img = Image.open(img_buffer)
201
+
202
+ return img
203
+
# Build the Gradio UI: a text box for the model ID, a button, and an image
# output that shows whatever create_family_tree returns for that ID.
with gr.Blocks() as demo:
    gr.Markdown(TITLE)
    # Default to the module-level MODEL_ID instead of repeating the literal.
    model_id = gr.Textbox(label="Model ID", value=MODEL_ID)
    btn = gr.Button("Create tree")
    out = gr.Image()
    btn.click(fn=create_family_tree, inputs=model_id, outputs=out)

# The former trailing create_family_tree(MODEL_ID) call was dropped: it sat
# after launch() and its returned image was never used.
demo.queue(api_open=False).launch(show_api=False)
215
+
216
+