CinAI-ScriptSmart / script_analysis.py
Ashish1722's picture
Upload 23 files
4f038ca verified
Raw
History Blame Contribute Delete
7.04 kB
# script_analysis.py
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import json
import streamlit as st
from utils import client
import plotly.graph_objs as go
import plotly.express as px
from plotly.subplots import make_subplots
def analyze_script(thread_id, additional_context=None, poll_interval=0.5):
    """Run the script-analysis assistant on a thread and return its reply text.

    Starts a run of the hard-coded assistant on ``thread_id``, polls until the
    run leaves the ``queued`` / ``in_progress`` / ``cancelling`` states, and
    returns the text of the most recent assistant message on the thread.

    Args:
        thread_id: ID of an existing thread. Presumably it already holds the
            script to analyse (TODO confirm against caller).
        additional_context: Currently unused; kept for interface compatibility.
            NOTE(review): it was likely meant to be forwarded to the run.
            Confirm intent before wiring it up.
        poll_interval: Seconds to sleep between run-status checks, so the loop
            does not hammer the API.

    Returns:
        The assistant's reply text. Returns ``""`` if no assistant text message
        is found. Returns ``"Error: Run status is <status>"`` when the run ends
        in any status other than ``completed``.
    """
    import time

    run = client.beta.threads.runs.create(
        thread_id=thread_id,
        assistant_id="asst_0TVOqfDUPuaSxtea11xa7DB0",
    )
    # Poll until the run leaves the non-terminal states.
    while run.status in ['queued', 'in_progress', 'cancelling']:
        time.sleep(poll_interval)
        run = client.beta.threads.runs.retrieve(
            thread_id=thread_id,
            run_id=run.id,
        )

    if run.status != 'completed':
        return f"Error: Run status is {run.status}"

    # Ask for newest-first explicitly. The first assistant message found is
    # then the reply from this run, not an older one left on a reused thread.
    messages = client.beta.threads.messages.list(thread_id=thread_id, order="desc")
    for msg in messages:
        if msg.role != "assistant":
            continue
        for part in msg.content:
            # Content may include non-text parts (e.g. images); only text
            # parts carry ``.text.value``.
            if getattr(part, "type", None) == "text":
                return part.text.value
    return ""
# Narrative stages, in story order, that the analysis is expected to cover.
_SCRIPT_STAGES = [
    "Introduction", "Rising Action", "Midpoint", "Complications",
    "Climax", "Falling Action", "Resolution",
]
# Per-stage scores that are plotted. The radar charts and insights text treat
# them as values on a 0-1 scale.
_SCRIPT_ATTRIBUTES = [
    "intensity", "narrative_intensity", "pacing", "tension", "emotion", "action",
]


def process_script_analysis(analysis):
    """Parse the assistant's JSON script analysis and render it in Streamlit.

    ``analysis`` is expected to be a JSON object keyed by stage name
    ("Introduction" or "Rising_Action" style). Each value is an object of
    attribute scores. The function shows the raw reply and a warning for every
    stage or attribute that is missing or non-numeric. It then shows the
    tabulated scores, several Plotly charts and a short list of insights.

    Any failure is reported in the page via ``st.error`` instead of being
    raised. This is the UI's top-level boundary.

    Args:
        analysis: The assistant's reply, usually the string returned by
            ``analyze_script``. It may be wrapped in a Markdown code fence.
    """
    try:
        # Print raw data for debugging
        st.write("Raw data:")
        st.write(analysis)

        data = _parse_analysis_json(analysis)
        df, missing_stages, missing_attributes = _build_stage_frame(data)

        if missing_stages:
            st.warning(f"Missing data for stages: {', '.join(missing_stages)}")
        if missing_attributes:
            st.warning(f"Missing data for attributes: {', '.join(missing_attributes)}")

        st.write("### Processed data:")
        st.dataframe(df)

        _render_charts(df)
        _render_insights(df)
    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing the app.
        st.error(f"Error processing data for Script Analysis: {e}")
        st.write("Please check the structure of the JSON data:")
        st.json(analysis)


def _parse_analysis_json(analysis):
    """Decode the reply as a JSON object, accepting an optional ``` fence.

    Raises:
        ValueError: if the text is not valid JSON (``json.JSONDecodeError``)
            or decodes to something other than an object.
        TypeError: if ``analysis`` is not str/bytes (raised by ``json.loads``).
    """
    text = analysis
    if isinstance(text, str):
        text = text.strip()
        if text.startswith("```"):
            # Drop the opening fence line (``` or ```json) and a closing fence.
            _, _, text = text.partition("\n")
            text = text.strip()
            if text.endswith("```"):
                text = text[:-3]
    data = json.loads(text)
    if not isinstance(data, dict):
        raise ValueError("expected a JSON object keyed by stage name")
    return data


def _build_stage_frame(data):
    """Tabulate per-stage scores from the parsed reply.

    Returns:
        ``(df, missing_stages, missing_attributes)``.
        - ``df`` has one row per entry of ``_SCRIPT_STAGES``, in order, plus a
          ``stage`` column and a numeric column for every ``_SCRIPT_ATTRIBUTES``
          entry. Any extra keys from the reply are kept as extra columns.
        - ``missing_stages`` lists stages absent from ``data``.
        - ``missing_attributes`` lists attributes that are absent or non-numeric
          in at least one stage that *was* present.
        All such gaps are filled with 0 in ``df``.
    """
    rows = []
    missing_stages = []
    for stage in _SCRIPT_STAGES:
        # Accept both "Rising Action" and "Rising_Action" style keys.
        key = stage if stage in data else stage.replace(" ", "_")
        if key in data:
            # Copy so tagging the row with its stage does not mutate `data`.
            row = dict(data[key])
        else:
            missing_stages.append(stage)
            row = {}
        row['stage'] = stage
        rows.append(row)

    df = pd.DataFrame(rows)
    for attr in _SCRIPT_ATTRIBUTES:
        if attr not in df.columns:
            df[attr] = np.nan
        # Non-numeric scores become NaN so they are reported and zero-filled.
        df[attr] = pd.to_numeric(df[attr], errors='coerce')

    # Count gaps only in stages the reply actually contained. Wholly missing
    # stages are reported separately via `missing_stages`.
    present = ~df['stage'].isin(missing_stages)
    gaps = df.loc[present, _SCRIPT_ATTRIBUTES].isna().any()
    missing_attributes = [attr for attr in _SCRIPT_ATTRIBUTES if gaps[attr]]

    df[_SCRIPT_ATTRIBUTES] = df[_SCRIPT_ATTRIBUTES].fillna(0)
    return df, missing_stages, missing_attributes


def _render_charts(df):
    """Render the line, heatmap, radar, stacked-bar and parallel-coordinate charts."""
    attributes = list(_SCRIPT_ATTRIBUTES)

    # Line chart: each attribute traced across the stages.
    st.write("### Script Attributes Across Stages")
    fig = go.Figure()
    for attr in attributes:
        fig.add_trace(go.Scatter(x=df['stage'], y=df[attr], mode='lines+markers', name=attr.capitalize()))
    fig.update_layout(title='Script Attributes Across Stages', xaxis_title='Stage', yaxis_title='Score')
    st.plotly_chart(fig, use_container_width=True)

    # Heatmap: stages (rows) x attributes (columns).
    st.write("### Heatmap of Script Attributes")
    heatmap_data = df.set_index('stage')[attributes]
    fig = px.imshow(heatmap_data,
                    labels=dict(x="Attributes", y="Stages", color="Score"),
                    x=attributes,
                    y=heatmap_data.index,
                    color_continuous_scale="YlOrRd")
    fig.update_layout(title='Script Attributes Heatmap')
    st.plotly_chart(fig, use_container_width=True)

    # One radar chart per stage; the radial axis is fixed to 0-1.
    st.write("### Radar Charts for Each Stage")
    for _, row in df.iterrows():
        fig = go.Figure(data=go.Scatterpolar(
            r=row[attributes].values,
            theta=attributes,
            fill='toself',
        ))
        fig.update_layout(
            polar=dict(radialaxis=dict(visible=True, range=[0, 1])),
            showlegend=False,
            title=f"Attributes for {row['stage']}",
        )
        st.plotly_chart(fig, use_container_width=True)

    # Stacked bars: each attribute's contribution to a stage's cumulative score.
    st.write("### Stage Comparison")
    fig = go.Figure()
    for attr in attributes:
        fig.add_trace(go.Bar(x=df['stage'], y=df[attr], name=attr.capitalize()))
    fig.update_layout(barmode='stack', title='Attribute Composition by Stage',
                      xaxis_title='Stage', yaxis_title='Cumulative Score')
    st.plotly_chart(fig, use_container_width=True)

    # Parallel coordinates, coloured by each stage's position in the story.
    st.write("### Parallel Coordinates Plot")
    stage_positions = list(range(len(df)))
    fig = px.parallel_coordinates(df, color=stage_positions,
                                  dimensions=attributes,
                                  color_continuous_scale=px.colors.sequential.Viridis,
                                  color_continuous_midpoint=len(df) // 2)
    # Label the colour bar with stage names instead of positions.
    fig.update_layout(
        coloraxis_colorbar=dict(
            title="Stage",
            tickvals=stage_positions,
            ticktext=list(df['stage']),
            lenmode="pixels", len=300,
        )
    )
    fig.update_layout(title='Parallel Coordinates Plot of Script Attributes')
    st.plotly_chart(fig, use_container_width=True)


def _render_insights(df):
    """Write summary statistics derived from the (zero-filled) stage scores."""
    st.write("### Key Insights")
    st.write("Based on the analysis, here are some key insights about the script:")
    highest_intensity = df.loc[df['intensity'].idxmax(), 'stage']
    st.write(f"- The highest intensity occurs during the {highest_intensity} stage.")
    avg_pacing = df['pacing'].mean()
    st.write(f"- The average pacing of the script is {avg_pacing:.2f} out of 1.")
    emotion_variance = df['emotion'].var()
    # 0.1 is the original threshold separating "varied" from "consistent".
    journey = 'a highly varied' if emotion_variance > 0.1 else 'a consistent'
    st.write(f"- The emotional variance throughout the script is {emotion_variance:.2f}, indicating {journey} emotional journey.")