en-gin-eer's picture
Upload 4 files
8683d51
raw
history blame contribute delete
No virus
2.76 kB
import os
import re
import json
import networkx as nx
import matplotlib.pyplot as plt
# ---------------- DATA EXTRACTION ----------------
def parse_page(page_name):
    """Read an HTML page from disk and pull out its model and LoRA names.

    The file is decoded as UTF-8 and its full text is handed to
    ``extract_data_from_html``. Whatever that returns is passed back
    unchanged: a ``(model, loras)`` tuple, or ``None`` when no embedded
    ``__NEXT_DATA__`` JSON script is found.
    """
    with open(page_name, encoding='utf-8') as handle:
        html_text = handle.read()
    result = extract_data_from_html(html_text)
    return result
def extract_data_from_html(html_content):
    """Extract model and LoRA names from the page's embedded ``__NEXT_DATA__`` JSON.

    Args:
        html_content: Full HTML text of the page.

    Returns:
        The ``(model_name, lora_names)`` tuple produced by
        ``get_model_and_resources``, or ``None`` if the page contains no
        ``__NEXT_DATA__`` JSON script.

    Raises:
        json.JSONDecodeError: if the script body is not valid JSON.
    """
    # HTML ends a <script> element's content at the first "</script", so
    # stopping the non-greedy capture there is exact. This holds regardless of
    # which tag follows; the old pattern required "<script defer" next and could
    # overshoot into later markup. DOTALL lets the capture span newlines in case
    # the JSON is pretty-printed.
    pattern = r'id="__NEXT_DATA__" type="application/json">(.*?)</script>'
    match = re.search(pattern, html_content, re.DOTALL)
    if not match:
        return None
    data_dict = json.loads(match.group(1))
    return get_model_and_resources(data_dict)
def get_model_and_resources(data_dict):
    """Retrieve the model name and LoRA names from parsed ``__NEXT_DATA__`` JSON.

    Only the first tRPC query is read, at
    ``props.pageProps.trpcState.json.queries[0].state.data.meta``.

    Args:
        data_dict: The decoded JSON object.

    Returns:
        A ``(model_name, lora_names)`` tuple. ``model_name`` is ``meta['Model']``.
        ``lora_names`` lists, in order, the ``name`` of every entry in
        ``meta['resources']`` whose ``type`` is ``'lora'``.

    Raises:
        KeyError, IndexError, TypeError: if the JSON does not have the expected
        nested shape (for example a missing key, an empty ``queries`` list, or
        a null ``meta``).
    """
    # Walk the deep path once instead of repeating it for each field.
    meta = data_dict['props']['pageProps']['trpcState']['json']['queries'][0]['state']['data']['meta']
    model_name = meta['Model']
    # .get() skips resources that have no 'type' field instead of raising KeyError.
    lora_names = [r['name'] for r in meta['resources'] if r.get('type') == 'lora']
    return model_name, lora_names
# ---------------- GRAPH CONSTRUCTION ----------------
def build_graph(data, degree_threshold):
    """Build a model/LoRA bipartite graph and prune weakly connected nodes.

    Args:
        data: Mapping of page file name to the ``(model_name, lora_names)``
            tuple extracted from that page. Entries whose value is ``None``
            (pages where no data could be extracted) are skipped.
        degree_threshold: Nodes whose degree is strictly below this value are
            removed. Pruning is a single pass over the degrees of the fully
            built graph, so a surviving node's degree may drop afterwards.

    Returns:
        An ``nx.Graph``. Each node carries ``bipartite=0`` (model) or
        ``bipartite=1`` (LoRA). Each edge carries a ``page`` attribute: the
        page file name without its directory or final extension.
    """
    B = nx.Graph()
    for page_name, extracted in data.items():
        if extracted is None:
            # extract_data_from_html returns None for pages without the
            # embedded JSON; skip them rather than failing to unpack.
            continue
        model, loras = extracted
        # splitext strips only the final extension, so "a.b.html" -> "a.b" and
        # "./x.html" -> "x" (split('.')[0] gave "a" and "").
        page_id = os.path.splitext(os.path.basename(page_name))[0]
        B.add_node(model, bipartite=0)
        for lora in loras:
            # NOTE: add_node updates existing attributes, so a name used both as
            # a model and as a LoRA keeps whichever tag was assigned last.
            B.add_node(lora, bipartite=1)
            # A repeated (model, lora) pair keeps a single edge whose 'page'
            # is the last page seen.
            B.add_edge(model, lora, page=page_id)
    # Collect first, then remove, so every degree comes from the full graph.
    nodes_to_remove = [node for node, degree in B.degree() if degree < degree_threshold]
    B.remove_nodes_from(nodes_to_remove)
    return B
# ---------------- VISUALIZATION AND ANALYSIS ----------------
def visualize_bipartite(B):
    """Draw the model/LoRA bipartite graph and show it with matplotlib.

    Model nodes (``bipartite == 0``) are placed on one side of a bipartite
    layout and drawn in blue. Every other node is drawn in green on the
    opposite side.

    Args:
        B: Graph whose nodes all carry a ``bipartite`` attribute, as produced
            by ``build_graph``.
    """
    model_nodes = {n for n, d in B.nodes(data=True) if d['bipartite'] == 0}
    pos = nx.bipartite_layout(B, model_nodes)
    node_colors = ['#1f78b4' if node in model_nodes else '#33a02c' for node in B.nodes()]
    plt.figure(figsize=(10, 5))
    nx.draw(B, pos, with_labels=True, node_color=node_colors)
    plt.title("Bipartite Graph between Model Name and Lora Name")
    plt.show()
def most_connected_models(B, top_n=10):
    """Print the ``top_n`` most connected model nodes and their neighbours.

    Args:
        B: Graph whose nodes carry a ``bipartite`` attribute; nodes with
            ``bipartite == 0`` are treated as models.
        top_n: Maximum number of models to print, highest degree first.
    """
    # Keep graph order in a list rather than a set. sorted() is stable (even
    # with reverse=True), so models with equal degree come out in a reproducible
    # order. Set iteration order for strings changes between runs under hash
    # randomization.
    model_nodes = [n for n, d in B.nodes(data=True) if d['bipartite'] == 0]
    sorted_models = sorted(model_nodes, key=lambda m: B.degree(m), reverse=True)
    for model in sorted_models[:top_n]:
        loras = list(B.neighbors(model))
        print(f"Model: {model}, Connected Loras: {loras}")