import time
import plotly.graph_objects as go
from datetime import datetime, timedelta
# Number of samples per second. Chunk start/end sample offsets are divided by
# this value to obtain positions in seconds.
SAMPLING_RATE = 16_000

# Marker colour (CSS ``rgb(...)`` string) for each emotion label drawn in the
# plot; labels missing from this map fall back to black at the lookup site.
COLOR_MAP = {
    "Neutralità": "rgb(178, 178, 178)",
    "Rabbia": "rgb(160, 61, 62)",
    "Paura": "rgb(91, 57, 136)",
    "Gioia": "rgb(255, 255, 0)",
    "Sorpresa": "rgb(60, 175, 175)",
    "Tristezza": "rgb(64, 106, 173)",
    "Disgusto": "rgb(100, 153, 65)",
}
def _format_hms(seconds):
    """Format a duration/offset in seconds as ``HH:MM:SS``.

    Relies on ``time.gmtime``, so values of 24 hours or more wrap around.
    """
    return time.strftime('%H:%M:%S', time.gmtime(seconds))


def create_behaviour_gantt_plot(behaviour_chunks, confidence_threshold=60) -> "go.Figure":
    """Build a Plotly figure of per-chunk confidence over time.

    The figure contains a smoothed orange line through every chunk's
    midpoint, a shaded band between ``confidence_threshold`` and 100, a dashed
    horizontal line at the threshold, and one marker trace per known emotion
    for chunks at or above the threshold. Every emotion in the fixed legend
    order gets a legend entry, even when it has no points.

    Args:
        behaviour_chunks: Iterable of 5-tuples
            ``(start, end, label, height, emotion)``. ``start`` and ``end``
            are sample offsets (divided by ``SAMPLING_RATE`` to get seconds),
            ``label`` is the text shown in the hover box, ``height`` is
            multiplied by 100 and plotted on a 0-100 axis (presumably a
            confidence in [0, 1] -- confirm against the caller), and
            ``emotion`` is the emotion label. ``None`` or an empty iterable
            yields a figure with no data points.
        confidence_threshold: Value on the 0-100 scale at or above which a
            chunk gets an emotion marker; also the lower edge of the shaded
            band.

    Returns:
        The assembled ``plotly.graph_objects.Figure``.
    """
    print("Creating behaviour Gantt plot...")

    # Fixed legend order; emotions outside this list never get marker traces.
    emotion_order = [
        "Gioia",
        "Sorpresa",
        "Disgusto",
        "Tristezza",
        "Paura",
        "Rabbia",
        "Neutralità",
    ]

    # Materialise once so generators are accepted and empty input is detectable.
    chunks = [] if behaviour_chunks is None else list(behaviour_chunks)

    # Arbitrary anchor date: the x axis is a date axis formatted as %H:%M:%S,
    # so only the offset from midnight is ever displayed.
    base_time = datetime(2000, 1, 1)

    start_times = []
    end_times = []
    mid_times = []  # chunk midpoints, used as x positions for line and markers
    heights = []
    emotions = []
    hover_texts = []
    for start, end, label, height, emotion in chunks:
        start_s = start / SAMPLING_RATE
        end_s = end / SAMPLING_RATE
        start_times.append(base_time + timedelta(seconds=start_s))
        end_times.append(base_time + timedelta(seconds=end_s))
        mid_times.append(base_time + timedelta(seconds=(start_s + end_s) / 2))
        heights.append(height * 100)
        emotions.append(emotion)
        # Plotly hover text uses <br> for line breaks.
        hover_texts.append(
            f"Inizio: {_format_hms(start_s)}<br>"
            f"Fine: {_format_hms(end_s)}<br>"
            f"Durata: {_format_hms((end - start) / SAMPLING_RATE)}<br>"
            f"Testo: {label}<br>"
            f"Attendibilità: {height*100:.2f}%<br>"
            f"Emozione: {emotion}"
        )

    fig = go.Figure()

    # Shade the region above the threshold across the whole time span. Skipped
    # for empty input (there is no span); min/max keeps the band correct even
    # if the chunks are not sorted by time.
    if chunks:
        fig.add_shape(
            type="rect",
            x0=min(start_times),
            x1=max(end_times),
            y0=confidence_threshold,
            y1=100,
            fillcolor="rgba(188,223,241,0.8)",
            opacity=0.8,
            layer="below",
            line_width=0,
        )

    fig.add_hline(y=confidence_threshold, line_dash="dash", line_color="black", line_width=1)

    # Smoothed trend line through every chunk midpoint.
    fig.add_trace(
        go.Scatter(
            x=mid_times,
            y=heights,
            mode='lines',
            name='Disregolazione',
            line=dict(
                color='orange',
                width=2,
                shape='spline',  # enables smoothing
                smoothing=1.0,
            ),
            text=hover_texts,
            hoverinfo='text',
            showlegend=False,
        )
    )

    # Group chunks at or above the threshold by emotion.
    emotion_data = {}
    for mid_time, height, emotion, hover_text in zip(mid_times, heights, emotions, hover_texts):
        if height >= confidence_threshold:
            bucket = emotion_data.setdefault(
                emotion,
                {'times': [], 'heights': [], 'hover_texts': []},
            )
            bucket['times'].append(mid_time)
            bucket['heights'].append(height)
            bucket['hover_texts'].append(hover_text)

    for emotion in emotion_order:
        scatter_kwargs = dict(
            mode='markers',
            name=emotion.capitalize(),
            marker=dict(
                size=15,
                color=COLOR_MAP.get(emotion, '#000000'),
                symbol='circle',
            ),
            showlegend=True,
        )
        data = emotion_data.get(emotion)
        if data is not None:
            scatter_kwargs.update(
                x=data['times'],
                y=data['heights'],
                text=data['hover_texts'],
                hoverinfo='text',
            )
        else:
            # Placeholder point so the emotion still appears in the legend.
            scatter_kwargs.update(x=[None], y=[None])
        fig.add_trace(go.Scatter(**scatter_kwargs))

    fig.update_layout(
        title='Distribuzione della disregolazione',
        xaxis_title='Tempo',
        yaxis_title='Attendibilità',
        xaxis=dict(
            type='date',
            tickformat='%H:%M:%S',
            showline=True,
            zeroline=False,
            side='bottom',
            showgrid=False,
        ),
        yaxis=dict(
            range=[0, 100],
            tickvals=[0, 20, 40, 60, 80, 100],
            ticktext=['0%', '20%', '40%', '60%', '80%', '100%'],
            tickmode='array',
            showgrid=False,
        ),
        legend_title=None,
        legend=dict(
            yanchor="top"
        ),
        hoverlabel=dict(
            font_size=12,
            font_family="Arial"
        ),
        paper_bgcolor='white',
        plot_bgcolor='white',
    )
    # Clear any hovertemplate on the traces.
    fig.update_traces(hovertemplate=None)
    return fig