# TensorFlowClass / pages/21_GraphRag.py
# Repository page header captured along with the file (not part of the program):
# eaglelandsonce's picture | Update pages/21_GraphRag.py | 510db06 (verified) | raw | history blame | 3.91 kB
import streamlit as st
from transformers import AutoTokenizer, AutoModel
import torch
import networkx as nx
import matplotlib.pyplot as plt
from collections import Counter
import graphrag # Import the graphrag library
@st.cache_resource
def load_model():
    """Load the BERT tokenizer and build the GraphRAG model on top of BERT.

    The function is decorated with ``st.cache_resource`` so that Streamlit
    can reuse the returned objects instead of rebuilding them on every
    script rerun.

    Returns:
        tuple: ``(tokenizer, graph_rag_model)`` where ``tokenizer`` is the
        ``bert-base-uncased`` tokenizer and ``graph_rag_model`` is the
        ``graphrag.GraphRAG`` instance constructed around the BERT encoder.
    """
    checkpoint = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoder = AutoModel.from_pretrained(checkpoint)

    # NOTE(review): these keyword arguments may need adjusting to match
    # GraphRAG's actual constructor interface.
    rag_settings = {
        "num_labels": 2,  # For binary sentiment classification
        "num_hidden_layers": 2,
        "hidden_size": 768,
        "intermediate_size": 3072,
    }
    classifier = graphrag.GraphRAG(encoder, **rag_settings)
    return tokenizer, classifier
def text_to_graph(text):
    """Convert whitespace-separated text into a path-graph description.

    Each word becomes a node (indexed by its position) and every pair of
    consecutive words is joined by an undirected edge. The undirected edges
    are emitted in both directions: first all ``i -> i + 1`` pairs, then all
    ``i + 1 -> i`` pairs.

    Args:
        text: Input string; it is split on whitespace with ``str.split()``.

    Returns:
        dict: A mapping with the keys

        * ``"edge_index"`` -- two parallel lists ``[sources, targets]``.
        * ``"num_nodes"`` -- number of words.
        * ``"node_feat"`` -- one single-element list per word holding the
          Unicode code point (``ord``) of the word's first character.
        * ``"edge_attr"`` -- ``[1]`` for every directed edge.

        Empty or whitespace-only text yields zero nodes and empty lists.
    """
    words = text.split()
    num_nodes = len(words)

    # A chain of words is a simple path, so the edge list is just the
    # consecutive index pairs; compute it once instead of building a graph
    # object and re-enumerating its edges for every field.
    forward_sources = list(range(num_nodes - 1))
    forward_targets = list(range(1, num_nodes))

    edge_index = [
        forward_sources + forward_targets,
        forward_targets + forward_sources,
    ]
    return {
        "edge_index": edge_index,
        "num_nodes": num_nodes,
        # str.split() never yields empty strings, so word[0] is always valid.
        "node_feat": [[ord(word[0])] for word in words],
        # All edges share the same attribute; one entry per directed edge.
        "edge_attr": [[1] for _ in range(2 * len(forward_sources))],
    }
def analyze_text(text, tokenizer, model):
    """Run sentiment inference on *text* using its tokens and its word graph.

    Args:
        text: Raw text to analyze.
        tokenizer: Callable invoked as
            ``tokenizer(text, return_tensors="pt", truncation=True,
            max_length=512)``; its result is indexed with ``"input_ids"``
            and ``"attention_mask"``.
        model: Callable that receives the combined token and graph inputs
            as keyword arguments.

    Returns:
        tuple: ``(sentiment, confidence, graph)`` where ``sentiment`` is
        ``"Positive"`` or ``"Negative"``, ``confidence`` is the probability
        of the chosen class as a Python float, and ``graph`` is the dict
        produced by ``text_to_graph(text)``.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    word_graph = text_to_graph(text)

    # NOTE(review): the combined input layout may need adjusting to match
    # GraphRAG's actual input requirements.
    model_inputs = dict(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        edge_index=torch.tensor(word_graph["edge_index"], dtype=torch.long),
        node_feat=torch.tensor(word_graph["node_feat"], dtype=torch.float),
        edge_attr=torch.tensor(word_graph["edge_attr"], dtype=torch.float),
        num_nodes=word_graph["num_nodes"],
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        raw_output = model(**model_inputs)

    # NOTE(review): output handling may need adjusting to GraphRAG's actual
    # output format; a bare tensor is accepted when there is no `.logits`.
    logits = getattr(raw_output, "logits", raw_output)
    class_probs = torch.softmax(logits, dim=1)
    negative_prob = class_probs[0][0]
    positive_prob = class_probs[0][1]

    # Ties go to "Negative", since "Positive" requires a strictly larger value.
    if positive_prob > negative_prob:
        return "Positive", positive_prob.item(), word_graph
    return "Negative", negative_prob.item(), word_graph
# ---------------------------------------------------------------------------
# Streamlit page: collect text, run the analysis, and render the results.
# ---------------------------------------------------------------------------
st.title("GraphRAG-based Text Analysis")

tokenizer, model = load_model()

text_input = st.text_area("Enter text for analysis:", height=200)

if st.button("Analyze Text"):
    # Whitespace-only input contains no words, so treat it like empty input
    # instead of sending an empty graph to the model.
    if text_input and text_input.strip():
        sentiment, confidence, graph = analyze_text(text_input, tokenizer, model)
        st.write(f"Sentiment: {sentiment}")
        st.write(f"Confidence: {confidence:.2f}")

        # Additional analysis
        word_count = len(text_input.split())
        st.write(f"Word count: {word_count}")

        # Most common words (only purely alphanumeric tokens, lower-cased)
        words = [word.lower() for word in text_input.split() if word.isalnum()]
        word_freq = Counter(words).most_common(5)
        st.write("Top 5 most common words:")
        for word, freq in word_freq:
            st.write(f"- {word}: {freq}")

        # Visualize graph
        G = nx.Graph()
        # Add nodes explicitly so that a single-word input (which has no
        # edges) still shows its node.
        G.add_nodes_from(range(graph["num_nodes"]))
        G.add_edges_from(zip(graph["edge_index"][0], graph["edge_index"][1]))

        # Draw on an explicit figure rather than pyplot's global state, and
        # close it afterwards so figures do not accumulate across reruns.
        fig, ax = plt.subplots(figsize=(10, 6))
        nx.draw(
            G,
            ax=ax,
            with_labels=False,
            node_size=30,
            node_color='lightblue',
            edge_color='gray',
        )
        ax.set_title("Text as Graph")
        st.pyplot(fig)
        plt.close(fig)
    else:
        st.write("Please enter some text to analyze.")