# gmedin's picture
# Update app.py
# d8ff0e1 verified
import streamlit as st
import networkx as nx
from pyvis.network import Network
import pickle
import math
import random
import requests
import os
from huggingface_hub import hf_hub_download
# Each supported brand maps to the pickle file name stored in its
# HuggingFace model repo; every file follows the "<brand>_graph.pkl" pattern.
BRAND_GRAPHS = {
    brand: f"{brand}_graph.pkl"
    for brand in ("drumeo", "pianote", "singeo", "guitareo")
}

# HuggingFace access configuration.
# (An earlier single-repo setting is kept below for reference.)
#HF_REPO = "MusoraProductDepartment/popular-path-graphs"
AUTH_TOKEN = os.environ.get("HF_TOKEN")  # None when the variable is unset
API_URL = "https://MusoraProductDepartment-PWGenerator.hf.space/rank_items/"
@st.cache_resource
def _load_graph_cached(brand):
    """
    Download and unpickle the graph for ``brand``, raising on any failure.

    This is the cached part of the loader. Errors are deliberately NOT caught
    here: a cached function only stores a value it returns, so letting the
    exception propagate means a failed download is retried on the next rerun
    instead of a ``None`` result being cached.
    """
    repo_id = f'MusoraProductDepartment/{brand}-graph'
    file_path = hf_hub_download(
        repo_id=repo_id,
        filename=BRAND_GRAPHS[brand],
        token=AUTH_TOKEN,
        cache_dir='/tmp',
        repo_type='model',
    )
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point this loader at repositories whose contents are trusted.
    with open(file_path, 'rb') as f:
        return pickle.load(f)


def load_graph_from_hf(brand):
    """
    Load the graph for the selected brand from HuggingFace Hub.

    Args:
        brand: A key of ``BRAND_GRAPHS``; it selects both the repository
            (``MusoraProductDepartment/<brand>-graph``) and the file name.

    Returns:
        The unpickled graph object, or ``None`` if the download or load
        failed (the error is shown to the user with ``st.error``).
    """
    try:
        return _load_graph_cached(brand)
    except Exception as e:
        st.error(f"Error loading graph from HuggingFace: {e}")
        return None


# Keep the cache-clearing handle available under the public name.
load_graph_from_hf.clear = _load_graph_cached.clear
def filter_graph(graph, node_threshold=10, edge_threshold=5):
    """
    Return a pruned copy of ``graph`` keeping only popular nodes and edges.

    A node is kept when its degree in ``graph`` is at least
    ``node_threshold``. Among the edges between kept nodes, any edge whose
    ``"weight"`` attribute (treated as 0 when absent) is below
    ``edge_threshold`` is removed. The input graph is not modified.
    """
    def is_popular(candidate):
        return graph.degree(candidate) >= node_threshold

    pruned = graph.subgraph(list(filter(is_popular, graph.nodes))).copy()

    # Collect first, remove afterwards, so the edge view is never mutated
    # while it is being iterated.
    weak_edges = [
        (src, dst)
        for src, dst, attrs in pruned.edges(data=True)
        if attrs.get("weight", 0) < edge_threshold
    ]
    for src, dst in weak_edges:
        pruned.remove_edge(src, dst)

    return pruned
def get_rankings_from_api(brand, user_id, content_ids, timeout=30):
    """
    Call the rank_items API to fetch rankings for the given user and content IDs.

    Args:
        brand: Brand name; sent upper-cased.
        user_id: User identifier; converted with ``int``.
        content_ids: Iterable of content identifiers; each converted with ``int``.
        timeout: Seconds to wait for the HTTP request before giving up, so a
            stalled endpoint cannot block the app indefinitely.

    Returns:
        The decoded JSON response, or an empty dict if anything failed
        (bad input, network error, timeout, non-2xx status, invalid JSON).
        Failures are reported to the user with ``st.error``.
    """
    try:
        payload = {
            "brand": brand.upper(),
            "user_id": int(user_id),
            "content_ids": [int(content_id) for content_id in content_ids],
        }
        headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
        }
        # Only send credentials when a token is configured; otherwise the
        # header would carry the literal text "Bearer None".
        if AUTH_TOKEN:
            headers["Authorization"] = f"Bearer {AUTH_TOKEN}"

        response = requests.post(API_URL, json=payload, headers=headers, timeout=timeout)
        response.raise_for_status()
        return response.json()
    except Exception as e:
        st.error(f"Error calling rank_items API: {e}")
        return {}
def rank_to_color(rank, max_rank):
    """
    Map a rank to a grayscale color, where dark gray indicates high relevance (low rank),
    and light gray indicates low relevance (high rank).

    Args:
        rank: Position of the item; 0 is the most relevant.
        max_rank: Largest rank considered "ranked".

    Returns:
        ``"rgb(i, i, i)"`` with ``i`` running from 55 (rank 0) to 255
        (rank == max_rank), or ``"#E8E8E8"`` when the item is unranked
        (``rank > max_rank``) or there are no rankings at all
        (``max_rank <= 0``).
    """
    # max_rank <= 0 means there is nothing to scale against; treating it as
    # "unranked" also avoids dividing by zero below.
    if max_rank <= 0 or rank > max_rank:
        return "#E8E8E8"  # Very light gray for unranked items
    intensity = int(55 + (rank / max_rank) * 200)  # Darker for lower ranks
    return f"rgb({intensity}, {intensity}, {intensity})"  # Grayscale
def _build_rank_lookup(rankings):
    """Map each ranked content ID (as a string) to its first 0-based position."""
    lookup = {}
    for position, content_id in enumerate(rankings or ()):
        # Graph node IDs are handled as strings here, so compare as strings.
        # setdefault keeps the first occurrence, like list.index would.
        lookup.setdefault(str(content_id), position)
    return lookup


def _add_content_node(net, graph, node_id, rank, max_rank, show_titles, background=None):
    """Add ``node_id`` to ``net`` with degree-based size and rank-based border."""
    title = graph.nodes[node_id].get('title', 'No title available')
    in_degree = graph.in_degree(node_id)
    out_degree = graph.out_degree(node_id)

    if background is None:
        # Color by which direction dominates the node's traffic.
        if in_degree > out_degree * 1.5:
            background = 'red'
        elif out_degree > in_degree * 1.5:
            background = 'green'
        else:
            background = 'lightblue'

    # Without rankings (max_rank == 0) the border simply matches the fill.
    border = rank_to_color(rank, max_rank) if max_rank else background
    label = f"{node_id}: {title[:15]}..." if show_titles else node_id

    net.add_node(
        node_id,
        label=label,
        title=f"{title}, In-degree: {in_degree}, Out-degree: {out_degree}, Rank: {rank}",
        size=(in_degree + out_degree) * 0.15,
        color={"background": background, "border": border},
        borderWidth=3,
        borderWidthSelected=6,
    )


def dynamic_visualize_graph(graph, start_node, layers=3, top_k=5, show_titles=False, rankings=None):
    """
    Render a layered expansion of ``graph`` around ``start_node`` in Streamlit.

    Starting from ``start_node``, the graph is explored breadth-first for
    ``layers`` rounds; from every node only its ``top_k`` heaviest outgoing
    edges are followed. The result is drawn with pyvis and embedded as HTML.

    Args:
        graph: Directed graph whose node IDs are strings; nodes may carry a
            ``title`` attribute and edges a ``weight`` attribute (0 if absent).
        start_node: ID of the node to start from; must exist in ``graph``.
        layers: Number of expansion rounds.
        top_k: Maximum number of outgoing edges followed per node.
        show_titles: When true, labels include the first 15 title characters.
        rankings: Optional ordered sequence of content IDs, most relevant
            first. When given, node borders are shaded with ``rank_to_color``;
            IDs missing from it get rank ``len(rankings) + 1``.
    """
    net = Network(notebook=False, width="100%", height="600px", directed=True)
    net.set_options("""
    var options = {
      "physics": {
        "barnesHut": {
          "gravitationalConstant": -15000,
          "centralGravity": 0.8
        }
      }
    }
    """)

    # One dict lookup per node instead of scanning the ranking list each time.
    rank_lookup = _build_rank_lookup(rankings)
    max_rank = len(rankings) if rankings else 0
    unranked = max_rank + 1

    # The starting node always gets a dark blue fill so it stands out.
    start_id = str(start_node)
    _add_content_node(
        net, graph, start_id,
        rank_lookup.get(start_id, unranked), max_rank, show_titles,
        background="darkblue",
    )

    visited_nodes = {start_id}
    added_edges = set()
    current_nodes = [start_id]

    for _ in range(layers):
        if not current_nodes:
            break  # nothing left to expand
        next_nodes = []
        for node in current_nodes:
            neighbors = sorted(
                ((str(neighbor), data.get('weight', 0)) for neighbor, data in graph[node].items()),
                key=lambda pair: pair[1],
                reverse=True,
            )[:top_k]
            for neighbor, weight in neighbors:
                if neighbor not in visited_nodes:
                    _add_content_node(
                        net, graph, neighbor,
                        rank_lookup.get(neighbor, unranked), max_rank, show_titles,
                    )
                    visited_nodes.add(neighbor)
                    # Each node is queued for expansion only once.
                    next_nodes.append(neighbor)
                # Edges are drawn even when the neighbor was already placed,
                # so links back to earlier nodes stay visible.
                edge = (node, neighbor)
                if edge not in added_edges:
                    edge_width = math.log(weight + 1) * 8
                    net.add_edge(node, neighbor, label=f"w:{weight}", width=edge_width, color='lightgray')
                    added_edges.add(edge)
        current_nodes = next_nodes

    html_content = net.generate_html()
    st.components.v1.html(html_content, height=600, scrolling=False)
st.title("Popular Path Expansion + Personalization")

# Brand Selection
selected_brand = st.selectbox("Select a brand:", options=list(BRAND_GRAPHS.keys()))

G = load_graph_from_hf(selected_brand)
if G is None:
    # load_graph_from_hf has already shown the error; nothing more can be drawn.
    st.stop()
if len(G) == 0:
    st.error("The graph for the selected brand is empty.")
    st.stop()

# Pick a fresh starting node whenever the brand changes (or none is stored yet,
# e.g. because an earlier run stopped before reaching this point).
brand_changed = st.session_state.get("selected_brand") != selected_brand
if brand_changed or "start_node" not in st.session_state:
    st.session_state.selected_brand = selected_brand
    # Sort nodes by popularity (in-degree + out-degree) and select from top 20
    popular_nodes = sorted(G.nodes, key=lambda n: G.in_degree(n) + G.out_degree(n), reverse=True)
    st.session_state.start_node = random.choice(popular_nodes[:20])

# Random Selection Button
if st.button("Random Selection"):
    st.session_state.start_node = random.choice(list(G.nodes))

start_node = st.text_input(
    "Enter the starting node ID:",
    value=str(st.session_state.start_node)
).strip()

# Input: Student ID
student_id = st.text_input("Enter a student ID (optional):", value="").strip()

# Toggle for showing content titles
show_titles = st.checkbox("Show content titles", value=False)

# Filter the graph
node_degree_threshold = 1
edge_weight_threshold = 1
G_filtered = filter_graph(G, node_threshold=node_degree_threshold, edge_threshold=edge_weight_threshold)

# Fetch rankings if student ID is provided
rankings = []
if student_id:
    try:
        student_id_number = int(student_id)
    except ValueError:
        st.error("Please enter a valid numeric student ID.")
    else:
        content_ids = list(G_filtered.nodes)
        api_response = get_rankings_from_api(selected_brand, student_id_number, content_ids)
        # The helper returns {} on failure; also tolerate a response that
        # lacks the expected key instead of raising KeyError.
        if isinstance(api_response, dict):
            rankings = api_response.get('ranked_content_ids') or []

layers = st.slider("Depth to explore:", 1, 6, value=3)
top_k = st.slider("Branching factor (per node):", 1, 6, value=3)

if st.button("Expand Graph"):
    if start_node in G_filtered:
        dynamic_visualize_graph(G_filtered, start_node, layers=layers, top_k=top_k, show_titles=show_titles, rankings=rankings)
    else:
        st.error("The starting node is not in the graph!")