Spaces:
Runtime error
Runtime error
| # script_analysis.py | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import numpy as np | |
| import json | |
| import streamlit as st | |
| from utils import client | |
| import plotly.graph_objs as go | |
| import plotly.express as px | |
| from plotly.subplots import make_subplots | |
def analyze_script(thread_id, additional_context=None, poll_interval=1.0):
    """Run the script-analysis assistant on a thread and return its reply text.

    Args:
        thread_id: ID of an existing Assistants-API thread that the assistant
            should analyse.
        additional_context: Accepted for caller compatibility but currently
            unused. NOTE(review): confirm whether this was meant to be forwarded
            to the run (e.g. as extra instructions).
        poll_interval: Seconds to wait between run-status polls.

    Returns:
        The text of the newest assistant message produced by this run ("" if
        the run produced none), or ``"Error: Run status is <status>"`` when the
        run finishes in any status other than ``completed``.
    """
    import time

    run = client.beta.threads.runs.create(
        thread_id=thread_id,
        assistant_id="asst_0TVOqfDUPuaSxtea11xa7DB0",
    )
    # Wait for the run to leave its transient states. Sleeping between polls
    # avoids a tight loop of retrieve calls against the API rate limit.
    while run.status in ("queued", "in_progress", "cancelling"):
        time.sleep(poll_interval)
        run = client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run.id)

    if run.status != "completed":
        return f"Error: Run status is {run.status}"

    # Newest first, and only messages written by *this* run, so a thread that
    # is reused for several analyses never yields an older reply. Stopping at
    # the first match also avoids paging through the whole thread.
    messages = client.beta.threads.messages.list(thread_id=thread_id, order="desc")
    for msg in messages:
        if msg.role == "assistant" and msg.run_id == run.id:
            # Content parts may include non-text items (e.g. images); keep text only.
            return "\n".join(
                part.text.value for part in msg.content if part.type == "text"
            )
    return ""
# Story stages in narrative order, and the per-stage scores the dashboard plots.
_SCRIPT_STAGES = [
    "Introduction",
    "Rising Action",
    "Midpoint",
    "Complications",
    "Climax",
    "Falling Action",
    "Resolution",
]
_SCRIPT_ATTRIBUTES = ["intensity", "narrative_intensity", "pacing", "tension", "emotion", "action"]


def _parse_analysis_json(analysis):
    """Decode the assistant reply into a dict.

    Accepts an already-decoded dict, or a JSON string that may be wrapped in a
    Markdown code fence (``` or ```json), which chat models frequently emit.

    Raises:
        ValueError: if the text is not valid JSON (``json.JSONDecodeError`` is a
            ``ValueError`` subclass) or does not decode to a JSON object.
    """
    if isinstance(analysis, dict):
        return analysis
    text = str(analysis).strip()
    if text.startswith("```"):
        _, _, text = text.partition("\n")  # drop the ``` / ```json opener line
        text = text.rsplit("```", 1)[0]  # drop the closing fence, if any
    data = json.loads(text)
    if not isinstance(data, dict):
        raise ValueError(f"expected a JSON object, got {type(data).__name__}")
    return data


def _build_stage_frame(data, stages, attributes):
    """Build one numeric row per stage from the decoded analysis.

    Each stage key is looked up as written ("Rising Action") and then with
    underscores ("Rising_Action"). A stage that is absent, or whose value is
    not a JSON object, becomes an all-zero placeholder row.

    Returns:
        ``(df, missing_stages, incomplete_attributes)``:
        ``df`` has a ``stage`` column, then every attribute (coerced to float,
        gaps filled with 0), then any extra keys the model returned;
        ``missing_stages`` lists stages with no usable data, in story order;
        ``incomplete_attributes`` lists attributes that were absent or
        non-numeric for at least one stage that *was* supplied.
    """
    rows, missing_stages = [], []
    for stage in stages:
        key = stage if stage in data else stage.replace(" ", "_")
        element = data.get(key)
        if isinstance(element, dict):
            row = dict(element)  # copy so the caller's parsed JSON is not mutated
        else:
            missing_stages.append(stage)
            row = {}
        row["stage"] = stage
        rows.append(row)

    df = pd.DataFrame(rows)
    for attr in attributes:
        if attr not in df.columns:
            df[attr] = float("nan")
        # The model may emit scores as strings; anything unparseable becomes NaN.
        df[attr] = pd.to_numeric(df[attr], errors="coerce")

    # Only gaps in stages the model actually supplied are worth warning about;
    # placeholder rows are already reported via missing_stages.
    supplied = ~df["stage"].isin(missing_stages)
    incomplete_attributes = [
        attr for attr in attributes if df.loc[supplied, attr].isna().any()
    ]
    df[attributes] = df[attributes].fillna(0)

    extra = [c for c in df.columns if c != "stage" and c not in attributes]
    return df[["stage", *attributes, *extra]], missing_stages, incomplete_attributes


def _render_charts(df, attributes):
    """Draw the line, heatmap, radar, stacked-bar and parallel-coordinates charts."""
    stage_names = df["stage"].tolist()

    # Interactive line chart for all attributes
    st.write("### Script Attributes Across Stages")
    fig = go.Figure()
    for attr in attributes:
        fig.add_trace(
            go.Scatter(x=stage_names, y=df[attr], mode="lines+markers", name=attr.capitalize())
        )
    fig.update_layout(title="Script Attributes Across Stages", xaxis_title="Stage", yaxis_title="Score")
    st.plotly_chart(fig, use_container_width=True)

    # Interactive heatmap
    st.write("### Heatmap of Script Attributes")
    heatmap_data = df.set_index("stage")[attributes]
    fig = px.imshow(
        heatmap_data,
        labels=dict(x="Attributes", y="Stages", color="Score"),
        x=attributes,
        y=heatmap_data.index,
        color_continuous_scale="YlOrRd",
    )
    fig.update_layout(title="Script Attributes Heatmap")
    st.plotly_chart(fig, use_container_width=True)

    # One radar chart per stage. A shared radial range keeps stages comparable;
    # it widens past 1 if the model used a larger scale instead of clipping.
    st.write("### Radar Charts for Each Stage")
    radial_max = max(1.0, float(df[attributes].to_numpy().max()))
    for _, row in df.iterrows():
        fig = go.Figure(
            data=go.Scatterpolar(r=row[attributes].tolist(), theta=attributes, fill="toself")
        )
        fig.update_layout(
            polar=dict(radialaxis=dict(visible=True, range=[0, radial_max])),
            showlegend=False,
            title=f"Attributes for {row['stage']}",
        )
        st.plotly_chart(fig, use_container_width=True)

    # Stacked bar chart comparing stages
    st.write("### Stage Comparison")
    fig = go.Figure()
    for attr in attributes:
        fig.add_trace(go.Bar(x=stage_names, y=df[attr], name=attr.capitalize()))
    fig.update_layout(
        barmode="stack",
        title="Attribute Composition by Stage",
        xaxis_title="Stage",
        yaxis_title="Cumulative Score",
    )
    st.plotly_chart(fig, use_container_width=True)

    # Parallel coordinates: colour each line by its position in the story so
    # the colour bar can be labelled with stage names instead of numbers.
    st.write("### Parallel Coordinates Plot")
    stage_order = list(range(len(df)))
    fig = px.parallel_coordinates(
        df,
        color=stage_order,
        dimensions=attributes,
        color_continuous_scale=px.colors.sequential.Viridis,
        color_continuous_midpoint=len(df) // 2,
    )
    fig.update_layout(
        coloraxis_colorbar=dict(
            title="Stage",
            tickvals=stage_order,
            ticktext=stage_names,
            lenmode="pixels",
            len=300,
        ),
        title="Parallel Coordinates Plot of Script Attributes",
    )
    st.plotly_chart(fig, use_container_width=True)


def _render_insights(df):
    """Write a few headline observations derived from the stage scores."""
    st.write("### Key Insights")
    st.write("Based on the analysis, here are some key insights about the script:")
    highest_intensity = df.loc[df["intensity"].idxmax(), "stage"]
    st.write(f"- The highest intensity occurs during the {highest_intensity} stage.")
    avg_pacing = df["pacing"].mean()
    st.write(f"- The average pacing of the script is {avg_pacing:.2f} out of 1.")
    # NOTE(review): the 0.1 threshold assumes scores on a 0-1 scale — confirm.
    emotion_variance = df["emotion"].var()
    st.write(f"- The emotional variance throughout the script is {emotion_variance:.2f}, indicating {'a highly varied' if emotion_variance > 0.1 else 'a consistent'} emotional journey.")


def process_script_analysis(analysis):
    """Render the script-analysis dashboard in Streamlit.

    Args:
        analysis: The assistant reply text, expected to be a JSON object (which
            may be inside a Markdown code fence) mapping stage names to objects
            of attribute scores. An already-decoded dict is also accepted.

    Any failure is reported on the page via ``st.error``; nothing is raised.
    """
    stages, attributes = _SCRIPT_STAGES, _SCRIPT_ATTRIBUTES
    try:
        # Echo the raw reply to help debug malformed model output.
        st.write("Raw data:")
        st.write(analysis)

        data = _parse_analysis_json(analysis)
        df, missing_stages, incomplete_attributes = _build_stage_frame(data, stages, attributes)
        if missing_stages:
            st.warning(f"Missing data for stages: {', '.join(missing_stages)}")
        if incomplete_attributes:
            st.warning(f"Missing data for attributes: {', '.join(incomplete_attributes)}")

        st.write("### Processed data:")
        st.dataframe(df)

        _render_charts(df, attributes)
        _render_insights(df)
    except Exception as e:  # UI boundary: report in-page rather than crash the app
        st.error(f"Error processing data for Script Analysis: {e}")
        st.write("Please check the structure of the JSON data:")
        try:
            st.json(_parse_analysis_json(analysis))
        except (TypeError, ValueError):
            # Not JSON at all (e.g. an "Error: Run status is ..." string); st.json
            # expects valid JSON, so show the raw text instead.
            st.code(str(analysis))