# Hugging Face file-view header (page residue kept for provenance):
# author: en-gin-eer · last commit: 38f30e6 "Update app.py" · size: 3.62 kB · scan: No virus
import html
import os
import pickle

import networkx as nx
import pandas as pd
import streamlit as st
import tqdm

from analysis import build_graph, parse_page
def clean(results):
    """Return a new dict holding only the entries of *results* with non-empty values.

    Values that are falsy (e.g. None, an empty list or string) are dropped.
    The input mapping is not modified.
    """
    return {
        key: value
        for key, value in results.items()
        if value and len(value) > 0
    }
# Unless this session already holds the graph, load the parsed page data
# (parsing ./pages and caching it to data.pkl on the first run) and build the
# model/lora graph from it.
if "B_degree_threshold" not in st.session_state:
    # Initial threshold; matches the sidebar slider's default in main().
    st.session_state.B_degree_threshold = 10
if "B" not in st.session_state:
    if not os.path.exists('data.pkl'):
        page_folder = 'pages'
        results = {}
        for p in tqdm.tqdm(os.listdir(page_folder)):
            try:
                results[p] = parse_page(os.path.join(page_folder, p))
            except Exception as e:
                # Best-effort: a page that fails to parse is skipped rather than
                # aborting the whole build, but it is reported instead of being
                # dropped silently. tqdm.write keeps the progress bar intact.
                tqdm.tqdm.write(f"Skipping {p!r}: {e!r}")
        # Write to a temporary file and move it into place so an interrupted
        # run cannot leave a truncated data.pkl that later runs fail to load.
        tmp_path = 'data.pkl.tmp'
        with open(tmp_path, 'wb') as f:
            pickle.dump(results, f)
        os.replace(tmp_path, 'data.pkl')
    else:
        # NOTE(review): pickle.load can execute arbitrary code from the file;
        # this is only acceptable because data.pkl is written by the branch
        # above, never supplied externally.
        with open('data.pkl', 'rb') as f:
            results = pickle.load(f)
    st.session_state.results = clean(results)
    st.session_state.B = build_graph(st.session_state.results, st.session_state.B_degree_threshold)
# Streamlit app
def links_table_html(name_header, names, pages, base_url="https://civitai.com/images/"):
    """Return an HTML table pairing each name with a clickable image link.

    Args:
        name_header: Column title for *names* (e.g. "Lora Names").
        names: Node names to list, one row each.
        pages: Page ids parallel to *names*; each is appended to *base_url*.
        base_url: Prefix used to build each link URL.

    Returns:
        The table as an HTML string, intended for
        ``st.markdown(..., unsafe_allow_html=True)``.
    """
    urls = [f"{base_url}{page}" for page in pages]
    df = pd.DataFrame({
        # to_html is called with escape=False so the <a> tags render. The
        # page-derived text is therefore escaped here so it cannot inject markup.
        name_header: [html.escape(str(name)) for name in names],
        "Image Link": [
            f'<a href="{html.escape(url)}" target="_blank">{html.escape(url)}</a>'
            for url in urls
        ],
    })
    # pandas truncates cell text longer than display.max_colwidth (default 50)
    # with "..." when formatting for to_html, which would cut the <a> tags in
    # half; lift the limit for this call only.
    with pd.option_context("display.max_colwidth", None):
        return df.to_html(escape=False, index=False)


def main():
    """Render the app: degree-threshold slider plus model and lora lookup tables."""
    st.title("SD BaseModel Lora Connections")

    # Sidebar control for the degree threshold passed to build_graph.
    B_degree_threshold = st.sidebar.slider("Select Degree Threshold", 1, 100, 10)

    # Rebuild the graph only when the threshold actually changed.
    if B_degree_threshold != st.session_state.B_degree_threshold:
        st.session_state.B_degree_threshold = B_degree_threshold
        st.session_state.B = build_graph(st.session_state.results, B_degree_threshold)
    B = st.session_state.B
    st.sidebar.write(f"There are {len(B)} nodes analyzed.")

    # Nodes whose 'bipartite' attribute is 0 are models; all others are loras.
    model_nodes = {n for n, d in B.nodes(data=True) if d['bipartite'] == 0}
    lora_nodes = set(B) - model_nodes

    # Highest degree first; ties are broken by name so the option order does
    # not depend on set iteration order.
    def by_degree(node):
        return (-B.degree(node), str(node))

    sorted_models = sorted(model_nodes, key=by_degree)
    sorted_loras = sorted(lora_nodes, key=by_degree)

    # Model selection: list the loras connected to the chosen model.
    selected_model = st.selectbox("Select Model (sorted by degree)", sorted_models)
    if selected_model:
        loras_for_model = list(B.neighbors(selected_model))
        pages = [B[selected_model][lora]['page'] for lora in loras_for_model]
        st.markdown(
            links_table_html("Lora Names", loras_for_model, pages),
            unsafe_allow_html=True,
        )

    # Lora selection: list the models connected to the chosen lora.
    selected_lora = st.selectbox("Select Lora (sorted by degree)", sorted_loras)
    if selected_lora:
        models_for_lora = list(B.neighbors(selected_lora))
        pages = [B[model][selected_lora]['page'] for model in models_for_lora]
        st.markdown(
            links_table_html("Model Names", models_for_lora, pages),
            unsafe_allow_html=True,
        )


if __name__ == "__main__":
    main()