|
import os |
|
import re |
|
import json |
|
import networkx as nx |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
def parse_page(page_name):
    """Read one saved HTML page and summarise its model and LoRA usage.

    Args:
        page_name: Path of the HTML file to read (decoded as UTF-8).

    Returns:
        Whatever ``extract_data_from_html`` produces for the file's text:
        a ``(model_name, lora_names)`` tuple, or ``None`` when the page
        carries no embedded data.
    """
    with open(page_name, encoding='utf-8') as handle:
        contents = handle.read()
    return extract_data_from_html(contents)
|
|
|
# Captures the JSON payload of the script tag with id="__NEXT_DATA__".
# re.DOTALL lets the payload span several lines. The lazy group stops at the
# first "</script>", which cannot appear inside a script element's body, so
# the capture never depends on which tag follows.
_NEXT_DATA_RE = re.compile(
    r'id="__NEXT_DATA__" type="application/json">(.*?)</script>',
    re.DOTALL,
)


def extract_data_from_html(html_content):
    """Extract model and LoRA names from the JSON embedded in an HTML page.

    Args:
        html_content: Raw HTML text of the page.

    Returns:
        The ``(model_name, lora_names)`` tuple built by
        ``get_model_and_resources``, or ``None`` if the page has no
        ``__NEXT_DATA__`` JSON script.

    Raises:
        json.JSONDecodeError: If the embedded payload is not valid JSON.
    """
    match = _NEXT_DATA_RE.search(html_content)
    if match is None:
        return None

    data_dict = json.loads(match.group(1))
    return get_model_and_resources(data_dict)
|
|
|
def get_model_and_resources(data_dict):
    """Pull the base model name and LoRA names out of the parsed page JSON.

    Both values are read from
    ``props.pageProps.trpcState.json.queries[0].state.data.meta``.

    Args:
        data_dict: The decoded JSON object embedded in the page.

    Returns:
        A ``(model_name, lora_names)`` tuple. ``lora_names`` lists, in order,
        the ``name`` of every resource whose ``type`` is ``'lora'``. It is
        empty when ``meta`` has no (or a null) ``resources`` entry.

    Raises:
        KeyError, IndexError: If the path to ``meta`` or its ``'Model'`` key
            is missing.
    """
    # Walk the deep path once instead of repeating it for each field.
    meta = data_dict['props']['pageProps']['trpcState']['json']['queries'][0]['state']['data']['meta']
    model_name = meta['Model']

    # A missing/null resource list means "no LoRAs". A resource without a
    # 'type' field is simply not a LoRA, not a reason to crash.
    resources = meta.get('resources') or []
    lora_names = [r['name'] for r in resources if r.get('type') == 'lora']
    return model_name, lora_names
|
|
|
|
|
|
|
def build_graph(data, degree_threshold):
    """Build a model/LoRA bipartite graph and prune weakly connected nodes.

    Args:
        data: Mapping of page file name -> ``(model_name, lora_names)``.
            Entries whose value is ``None`` (pages with no extractable data)
            are skipped.
        degree_threshold: After the graph is built, every node whose degree
            is strictly less than this is removed. Pruning is a single pass,
            so removing a node can leave a neighbour below the threshold.

    Returns:
        A ``networkx.Graph``. Model nodes carry ``bipartite=0`` and LoRA nodes
        carry ``bipartite=1``. Each edge has a ``page`` attribute holding the
        page name without its final extension.
    """
    B = nx.Graph()

    for page_name, entry in data.items():
        if entry is None:
            continue
        model, loras = entry
        # splitext strips only the last extension: "./pages/a.b.html" becomes
        # "./pages/a.b", where str.split('.')[0] would have produced "".
        page_id = os.path.splitext(page_name)[0]

        # NOTE(review): a name used both as a model and as a LoRA keeps the
        # bipartite tag that was set last. A model/LoRA pair seen on several
        # pages keeps only the last page on its edge.
        B.add_node(model, bipartite=0)
        for lora in loras:
            B.add_node(lora, bipartite=1)
            B.add_edge(model, lora, page=page_id)

    low_degree = [node for node, degree in B.degree() if degree < degree_threshold]
    B.remove_nodes_from(low_degree)
    return B
|
|
|
|
|
|
|
|
|
def visualize_bipartite(B):
    """Draw the bipartite graph with models on one side and LoRAs on the other.

    Args:
        B: Graph whose nodes each carry a ``bipartite`` attribute, where 0
            marks a model node; every other node is treated as a LoRA.

    Shows the figure with ``plt.show()`` and returns ``None``.
    """
    model_nodes = {n for n, attrs in B.nodes(data=True) if attrs['bipartite'] == 0}
    pos = nx.bipartite_layout(B, model_nodes)

    # One colour per node, in B.nodes() order: blue for models, green for LoRAs.
    colors = ['#1f78b4' if node in model_nodes else '#33a02c' for node in B.nodes()]

    plt.figure(figsize=(10, 5))
    nx.draw(B, pos, with_labels=True, node_color=colors)
    plt.title("Bipartite Graph between Model Name and Lora Name")
    plt.show()
|
|
|
def most_connected_models(B, top_n=10):
    """Print the ``top_n`` models with the most LoRA connections.

    Args:
        B: Graph whose nodes carry a ``bipartite`` attribute, 0 for models.
        top_n: Maximum number of models to print.

    Models are ordered by descending degree. Ties are broken by model name,
    so the output is identical on every run; iterating a set of strings
    otherwise changes order with hash randomization.
    """
    model_nodes = [n for n, attrs in B.nodes(data=True) if attrs['bipartite'] == 0]
    ranked = sorted(model_nodes, key=lambda m: (-B.degree(m), str(m)))

    for model in ranked[:top_n]:
        loras = list(B.neighbors(model))
        print(f"Model: {model}, Connected Loras: {loras}")
|
|