| | import streamlit as st |
| | import networkx as nx |
| | from pyvis.network import Network |
| | import pickle |
| | import math |
| | import random |
| | import requests |
| | import os |
| | from huggingface_hub import hf_hub_download |
| |
|
| | |
# Pickled graph file for each supported brand, keyed by lowercase brand name.
# The key order is also the order in which brands are offered in the selector.
BRAND_GRAPHS = {
    "drumeo": "drumeo_graph.pkl",
    "pianote": "pianote_graph.pkl",
    "singeo": "singeo_graph.pkl",
    "guitareo": "guitareo_graph.pkl",
}

# Token taken from the HF_TOKEN environment variable (None when it is unset).
# It is passed to the Hub download and sent as the bearer token to the
# ranking API.
AUTH_TOKEN = os.environ.get("HF_TOKEN")

# Endpoint of the rank_items service.
API_URL = "https://MusoraProductDepartment-PWGenerator.hf.space/rank_items/"
| |
|
| |
|
@st.cache_resource
def load_graph_from_hf(brand):
    """
    Load the graph for the selected brand from HuggingFace Hub.

    The file named by ``BRAND_GRAPHS[brand]`` is downloaded from the
    ``MusoraProductDepartment/{brand}-graph`` model repository (using
    ``/tmp`` as the download cache) and unpickled.

    Args:
        brand: A key of ``BRAND_GRAPHS`` (e.g. ``'drumeo'``).

    Returns:
        The unpickled object, or ``None`` when the brand is unknown or the
        download/unpickling fails. Failures are reported with ``st.error``
        instead of being raised.
    """
    # NOTE(review): this function is wrapped in st.cache_resource, so a None
    # returned after a transient failure is presumably cached as well and
    # would stick until the cache is cleared -- confirm.
    try:
        if brand not in BRAND_GRAPHS:
            # Fail with a readable message rather than a bare KeyError repr.
            raise ValueError(
                f"unknown brand {brand!r}; expected one of {sorted(BRAND_GRAPHS)}"
            )

        file_path = hf_hub_download(
            repo_id=f'MusoraProductDepartment/{brand}-graph',
            filename=BRAND_GRAPHS[brand],
            token=AUTH_TOKEN,
            cache_dir='/tmp',
            repo_type='model',
        )

        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. This is only acceptable while the Hub repository is fully
        # trusted; never point this loader at a repository you do not control.
        with open(file_path, 'rb') as f:
            return pickle.load(f)
    except Exception as e:
        # UI boundary: show the problem in the app instead of crashing it.
        st.error(f"Error loading graph from HuggingFace: {e}")
        return None
| |
|
| |
|
def filter_graph(graph, node_threshold=10, edge_threshold=5):
    """
    Build a pruned copy of ``graph`` with only popular nodes and edges.

    Nodes whose degree is below ``node_threshold`` are left out. Among the
    remaining nodes, every edge whose ``weight`` attribute (taken as 0 when
    the attribute is missing) is below ``edge_threshold`` is removed. Edge
    removal is performed on the copy that is returned.
    """
    def is_popular(candidate):
        return graph.degree(candidate) >= node_threshold

    kept_nodes = list(filter(is_popular, graph.nodes))
    pruned = graph.subgraph(kept_nodes).copy()

    # Collect the weak edges first so the graph is not mutated while its
    # edge view is being iterated.
    weak_edges = [
        (source, target)
        for source, target, attrs in pruned.edges(data=True)
        if attrs.get("weight", 0) < edge_threshold
    ]
    for source, target in weak_edges:
        pruned.remove_edge(source, target)

    return pruned
| |
|
| |
|
def get_rankings_from_api(brand, user_id, content_ids, timeout=30):
    """
    Call the rank_items API to fetch rankings for the given user and content IDs.

    Args:
        brand: Brand name; sent upper-cased.
        user_id: User identifier; converted with ``int``.
        content_ids: Iterable of content identifiers; each converted with
            ``int``.
        timeout: Seconds to wait for the service before giving up. May also
            be a ``(connect, read)`` tuple as accepted by ``requests``.

    Returns:
        The decoded JSON body of the response, or ``{}`` when anything goes
        wrong (bad input, network failure, timeout, HTTP error status,
        invalid JSON). Failures are reported with ``st.error`` instead of
        being raised.
    """
    try:
        payload = {
            "brand": brand.upper(),
            "user_id": int(user_id),
            "content_ids": [int(content_id) for content_id in content_ids],
        }
        headers = {
            "Authorization": f"Bearer {AUTH_TOKEN}",
            "accept": "application/json",
            "Content-Type": "application/json",
        }
        # requests has no default timeout: without one, a stalled endpoint
        # would block this script run indefinitely.
        response = requests.post(
            API_URL, json=payload, headers=headers, timeout=timeout
        )
        response.raise_for_status()
        return response.json()
    except Exception as e:
        # UI boundary: show the problem in the app instead of crashing it.
        st.error(f"Error calling rank_items API: {e}")
        return {}
| |
|
| |
|
def rank_to_color(rank, max_rank):
    """
    Map a rank to a grayscale color, where dark gray indicates high relevance (low rank),
    and light gray indicates low relevance (high rank).

    Args:
        rank: Position of the item in the ranking (lower is more relevant).
        max_rank: Largest rank that is still considered ranked.

    Returns:
        ``"#E8E8E8"`` when the item is unranked (``rank > max_rank``) or when
        there is no usable ranking (``max_rank <= 0``); otherwise an
        ``"rgb(i, i, i)"`` string whose channel value grows linearly from 55
        at rank 0 to 255 at ``max_rank``.
    """
    # max_rank <= 0 means there is nothing to scale against; returning the
    # "unranked" color also avoids a division by zero below.
    if max_rank <= 0 or rank > max_rank:
        return "#E8E8E8"
    intensity = int(55 + (rank / max_rank) * 200)
    # Keep the channel inside the valid 0-255 range even for negative ranks.
    intensity = max(0, min(255, intensity))
    return f"rgb({intensity}, {intensity}, {intensity})"
| |
|
| |
|
def dynamic_visualize_graph(graph, start_node, layers=3, top_k=5, show_titles=False, rankings=None):
    """
    Render the neighbourhood of ``start_node`` as an interactive pyvis network.

    Starting from ``start_node``, the graph is explored breadth-first for
    ``layers`` rounds. In every round each frontier node contributes its
    ``top_k`` heaviest outgoing edges (by the ``weight`` edge attribute,
    taken as 0 when missing). The resulting network is embedded in the
    Streamlit page.

    Visual encoding:
        * node size: ``(in_degree + out_degree) * 0.15``
        * start node background: ``darkblue``
        * other backgrounds: ``red`` when in-degree exceeds 1.5x out-degree,
          ``green`` when out-degree exceeds 1.5x in-degree, else ``lightblue``
        * border: grayscale from ``rank_to_color`` when ``rankings`` is
          non-empty, otherwise the background color
        * edge width: ``log(weight + 1) * 8``, labelled ``w:<weight>``

    Args:
        graph: Directed graph whose node ids are strings; nodes may carry a
            ``title`` attribute.
        start_node: Id of the node to start from (converted with ``str``).
            It must be present in ``graph``.
        layers: Number of expansion rounds.
        top_k: Maximum number of outgoing edges followed per node.
        show_titles: When true, labels are ``"<id>: <first 15 chars of title>..."``
            instead of just the id.
        rankings: Optional sequence of content ids ordered from most to least
            relevant. Ids are compared by their ``str`` form, so integer ids
            match string node ids. The rank shown for a node is its 0-based
            position; nodes not in the sequence get ``len(rankings) + 1``.

    Returns:
        None. The network HTML is written to the page.
    """
    # NOTE(review): the original nesting of the edge/bookkeeping statements
    # could not be recovered from the mangled source. This assumes an edge is
    # drawn even when its target was already visited (which is what the
    # `added_edges` guard implies) -- confirm.
    net = Network(notebook=False, width="100%", height="600px", directed=True)
    net.set_options("""
    var options = {
      "physics": {
        "barnesHut": {
          "gravitationalConstant": -15000,
          "centralGravity": 0.8
        }
      }
    }
    """)

    start_id = str(start_node)
    max_rank = len(rankings) if rankings else 0
    unranked = max_rank + 1

    # Build the id -> position table once; list.index per node would rescan
    # the whole ranking for every node drawn. setdefault keeps the first
    # position when an id occurs more than once, like list.index does.
    rank_by_id = {}
    for position, content_id in enumerate(rankings or ()):
        rank_by_id.setdefault(str(content_id), position)

    def add_content_node(node_id, background, in_degree, out_degree):
        """Add one node to the network with the shared label/tooltip/style."""
        title = graph.nodes[node_id].get('title', 'No title available')
        rank = rank_by_id.get(node_id, unranked)
        border = rank_to_color(rank, max_rank) if rankings else background
        label = str(node_id) if not show_titles else f"{str(node_id)}: {title[:15]}..."
        net.add_node(
            node_id,
            label=label,
            title=f"{title}, In-degree: {in_degree}, Out-degree: {out_degree}, Rank: {rank}",
            size=(in_degree + out_degree) * 0.15,
            color={"background": background, "border": border},
            borderWidth=3,
            borderWidthSelected=6,
        )

    add_content_node(
        start_id,
        'darkblue',
        graph.in_degree(start_id),
        graph.out_degree(start_id),
    )

    visited_nodes = {start_id}
    added_edges = set()
    current_nodes = [start_id]

    for _ in range(layers):
        if not current_nodes:
            break  # nothing left to expand
        next_nodes = []
        for node in current_nodes:
            # Heaviest outgoing edges first; unweighted edges count as 0,
            # consistent with filter_graph.
            neighbors = sorted(
                (
                    (str(neighbor), data.get('weight', 0))
                    for neighbor, data in graph[node].items()
                ),
                key=lambda pair: pair[1],
                reverse=True,
            )[:top_k]

            for neighbor, weight in neighbors:
                if neighbor not in visited_nodes:
                    in_degree = graph.in_degree(neighbor)
                    out_degree = graph.out_degree(neighbor)
                    if in_degree > out_degree * 1.5:
                        node_color = 'red'
                    elif out_degree > in_degree * 1.5:
                        node_color = 'green'
                    else:
                        node_color = 'lightblue'
                    add_content_node(neighbor, node_color, in_degree, out_degree)
                    visited_nodes.add(neighbor)
                    # Expanding a node a second time can add nothing new, so
                    # it is queued only when first discovered.
                    next_nodes.append(neighbor)

                edge = (node, neighbor)
                if edge not in added_edges:
                    edge_width = math.log(weight + 1) * 8
                    net.add_edge(node, neighbor, label=f"w:{weight}", width=edge_width, color='lightgray')
                    added_edges.add(edge)

        current_nodes = next_nodes

    html_content = net.generate_html()
    st.components.v1.html(html_content, height=600, scrolling=False)
| |
|
| |
|
st.title("Popular Path Expansion + Personalization")

# --- Brand selection and graph loading ---------------------------------
selected_brand = st.selectbox("Select a brand:", options=list(BRAND_GRAPHS.keys()))

# The loader is decorated with st.cache_resource, so calling it on every
# rerun is fine; there is no need for separate calls per branch.
G = load_graph_from_hf(selected_brand)
if G is None or len(G.nodes) == 0:
    # A failed load returns None; stop here instead of crashing on G.nodes.
    st.error("No graph is available for the selected brand.")
    st.stop()

# Pick a fresh default start node whenever the brand changes (or on the
# first run): a random choice among the 20 best-connected nodes.
brand_changed = st.session_state.get("selected_brand") != selected_brand
if brand_changed or "start_node" not in st.session_state:
    st.session_state.selected_brand = selected_brand
    popular_nodes = sorted(
        G.nodes,
        key=lambda n: G.in_degree(n) + G.out_degree(n),
        reverse=True,
    )
    st.session_state.start_node = random.choice(popular_nodes[:20])

# --- Start node ----------------------------------------------------------
if st.button("Random Selection"):
    st.session_state.start_node = random.choice(list(G.nodes))

# Node ids are strings; surrounding whitespace from typing/pasting is dropped.
start_node = st.text_input(
    "Enter the starting node ID:",
    value=str(st.session_state.start_node)
).strip()

# --- Personalization and display options -----------------------------------
student_id = st.text_input("Enter a student ID (optional):", value="").strip()

user_id = None
if student_id:
    try:
        user_id = int(student_id)
    except ValueError:
        # Keep the page usable: report the problem and fall back to the
        # non-personalized view rather than crashing the script.
        st.error("Please enter a valid numeric student ID.")

show_titles = st.checkbox("Show content titles", value=False)

# --- Graph filtering -------------------------------------------------------
node_degree_threshold = 1
edge_weight_threshold = 1
G_filtered = filter_graph(G, node_threshold=node_degree_threshold, edge_threshold=edge_weight_threshold)

layers = st.slider("Depth to explore:", 1, 6, value=3)
top_k = st.slider("Branching factor (per node):", 1, 6, value=3)

# --- Rendering ---------------------------------------------------------------
if st.button("Expand Graph"):
    if start_node in G_filtered:
        # Rankings are only needed for drawing, so the API is called here
        # rather than on every rerun triggered by any widget.
        rankings = []
        if user_id is not None:
            response = get_rankings_from_api(selected_brand, user_id, list(G_filtered.nodes))
            # The helper returns {} on failure; also tolerate a response that
            # lacks the expected key instead of raising KeyError.
            if isinstance(response, dict):
                rankings = response.get('ranked_content_ids') or []
        dynamic_visualize_graph(G_filtered, start_node, layers=layers, top_k=top_k, show_titles=show_titles, rankings=rankings)
    else:
        st.error("The starting node is not in the graph!")
| |
|