# (Residue from the page this file was copied from, kept as a comment so the
#  module parses: "Spaces: Sleeping Sleeping")
import itertools
import os

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import streamlit as st
import torch
from sklearn.preprocessing import StandardScaler
from torch_geometric.utils import from_networkx

from individual_analyzes import analyze_sentiment  # Assuming the sentiment analysis function is defined here
from utils import read_poems_from_directory
from vae_model import create_vae
def build_poem_graph(poems, sentiment_labels):
    """Build an undirected graph linking poems that share a sentiment label.

    Each poem becomes a node identified by its index in ``poems``. An edge
    ``(i, j)`` is added for every pair ``i < j`` whose entries in
    ``sentiment_labels`` compare equal with ``==``. For ordinary labels such
    as strings, each label class therefore becomes a fully connected cluster.

    Args:
        poems: Sequence of poems. Only its length is used, to create nodes
            ``0 .. len(poems) - 1``.
        sentiment_labels: Per-poem labels indexed like ``poems``. They are
            only compared with ``==``, so they need not be hashable.

    Returns:
        A ``networkx.Graph`` with one node per poem.

    Raises:
        IndexError: If ``sentiment_labels`` is shorter than ``poems``.
    """
    poem_graph = nx.Graph()
    node_ids = range(len(poems))
    poem_graph.add_nodes_from(node_ids)
    # Edges encode "same sentiment label", not textual similarity.
    # combinations() yields (i, j) with i < j in the same order as a nested
    # i/j index loop, so edges are inserted in the same order as before.
    poem_graph.add_edges_from(
        (i, j)
        for i, j in itertools.combinations(node_ids, 2)
        if sentiment_labels[i] == sentiment_labels[j]
    )
    return poem_graph
def visualize_poem_graph(poem_graph, sentiment_labels):
    """Render ``poem_graph`` in the Streamlit page, coloring nodes by sentiment.

    Nodes whose label equals ``'positive'`` are drawn sky blue; every other
    label is drawn light coral. Each node is annotated with its id, and the
    layout comes from ``nx.spring_layout``.

    Args:
        poem_graph: ``networkx`` graph to draw.
        sentiment_labels: One label per node, matched positionally to the
            graph's node order. This assumes the label order equals
            ``poem_graph.nodes`` order, which holds for graphs built with
            nodes ``0..n-1`` in order (TODO confirm for other callers).
    """
    pos = nx.spring_layout(poem_graph)
    colors = [
        'skyblue' if label == 'positive' else 'lightcoral'
        for label in sentiment_labels
    ]
    # Draw on a dedicated figure. Passing the pyplot module to st.pyplot
    # renders the shared global figure and never clears it, so repeated
    # calls kept drawing on top of earlier graphs.
    fig, ax = plt.subplots()
    try:
        nx.draw_networkx_nodes(
            poem_graph, pos, ax=ax, node_size=200, node_color=colors
        )
        nx.draw_networkx_edges(poem_graph, pos, ax=ax, edge_color='gray')
        nx.draw_networkx_labels(poem_graph, pos, ax=ax, font_size=10)
        ax.axis('off')
        st.pyplot(fig)
    finally:
        # Release the figure so repeated renders don't leave open figures.
        plt.close(fig)
def graph_guided_learning_page():
    """Render the "Graph Guided Learning" Streamlit page.

    Steps:
    - Read poems from ``./poems`` and label them with ``analyze_sentiment``.
    - Train the model returned by ``create_vae`` on a standardized per-poem
      feature.
    - Draw a graph linking poems with equal sentiment labels.
    - Write the encoder output to the page.

    Shows an error when ``./poems`` is not a directory, and a warning when
    no poems were read.
    """
    st.header("Graph Guided Learning")

    poems_directory = "./poems"
    if not os.path.isdir(poems_directory):
        st.error("The specified path is not a valid directory.")
        return

    poems = read_poems_from_directory(poems_directory)
    if not poems:
        st.warning("No poems found in the specified directory.")
        return

    sentiment_labels = analyze_sentiment(poems)

    # Placeholder feature extraction: one column holding each poem's length.
    raw_features = np.array([[len(poem)] for poem in poems])
    scaled = StandardScaler().fit_transform(raw_features)

    # Train the model to reconstruct its own input, then encode the poems.
    # NOTE(review): the model is retrained on every call of this page function.
    latent_size = 16
    vae, encoder = create_vae(scaled.shape[1], latent_size)
    vae.fit(scaled, scaled, epochs=50, batch_size=256, validation_split=0.2)
    # NOTE(review): assumes encoder.predict returns a single
    # (n_poems, latent_size) array; some VAE encoders return several outputs
    # (e.g. mean, log-variance, sample). Confirm against create_vae.
    latent_features = encoder.predict(scaled)

    graph = build_poem_graph(poems, sentiment_labels)
    visualize_poem_graph(graph, sentiment_labels)

    # NOTE(review): ``data`` (PyTorch Geometric format) is built but not used
    # in this function. Presumably it is meant for a later GNN step; confirm
    # before removing.
    data = from_networkx(graph)
    data.x = torch.tensor(latent_features, dtype=torch.float32)

    st.write("Latent Features:")
    st.write(latent_features)