|
import time |
|
import plotly.graph_objects as go |
|
from datetime import datetime, timedelta |
|
|
|
|
|
# Audio sampling rate (samples per second); sample offsets are divided by
# this value to obtain seconds.
SAMPLING_RATE = 16000

# Plotly colour assigned to each (Italian) emotion label. Built from an
# ordered list of pairs so the mapping keeps this exact key order.
COLOR_MAP = dict([
    ("Neutralità", "rgb(178, 178, 178)"),
    ("Rabbia", "rgb(160, 61, 62)"),
    ("Paura", "rgb(91, 57, 136)"),
    ("Gioia", "rgb(255, 255, 0)"),
    ("Sorpresa", "rgb(60, 175, 175)"),
    ("Tristezza", "rgb(64, 106, 173)"),
    ("Disgusto", "rgb(100, 153, 65)"),
])
|
|
|
def create_behaviour_gantt_plot(behaviour_chunks, confidence_threshold=60):
    """Plot per-chunk dysregulation confidence over time as a Plotly figure.

    Every chunk contributes one point, placed at its temporal midpoint, to a
    smoothed orange line. Chunks whose scaled confidence is at least
    ``confidence_threshold`` are also drawn as markers coloured via
    ``COLOR_MAP``. A shaded band and a dashed line mark the threshold.

    Args:
        behaviour_chunks: Iterable of 5-tuples
            ``(start, end, label, height, emotion)``.
            ``start`` and ``end`` are sample offsets, converted to seconds by
            dividing by ``SAMPLING_RATE``. ``label`` is shown as "Testo" in
            the hover box. ``height`` is multiplied by 100 for display and is
            presumably a fraction in [0, 1] (TODO confirm with caller).
            ``emotion`` is an emotion label. Labels missing from
            ``COLOR_MAP`` are drawn in black.
            Any iterable is accepted; it is materialised once.
        confidence_threshold: Threshold on the 0-100 scale, compared against
            ``height * 100``.

    Returns:
        plotly.graph_objects.Figure: the assembled figure. For empty input
        the figure holds only the threshold line and empty legend entries.
    """
    print("Creating behaviour Gantt plot...")

    # Preferred legend order. Emotions outside this list that still have
    # above-threshold points are appended after it, so they are not dropped.
    emotion_order = [
        "Gioia",
        "Sorpresa",
        "Disgusto",
        "Tristezza",
        "Paura",
        "Rabbia",
        "Neutralità",
    ]

    # Materialise once: the chunks are traversed several times below, and a
    # generator argument would otherwise be exhausted after the first pass.
    chunks = list(behaviour_chunks)

    # Offsets are anchored to an arbitrary fixed date so Plotly's date axis
    # can render them as HH:MM:SS ticks.
    base_time = datetime(2_000, 1, 1, 0, 0, 0)

    def hms(seconds):
        # time.gmtime wraps at 24 h, so longer offsets roll over in the label.
        return time.strftime('%H:%M:%S', time.gmtime(seconds))

    chunk_starts = []
    chunk_ends = []
    mid_times = []
    heights = []
    emotions = []
    hover_texts = []
    for start, end, label, height, emotion in chunks:
        start_s = start / SAMPLING_RATE
        end_s = end / SAMPLING_RATE
        chunk_starts.append(start_s)
        chunk_ends.append(end_s)
        mid_times.append(base_time + timedelta(seconds=(start_s + end_s) / 2))
        heights.append(height * 100)
        emotions.append(emotion)

        duration_str = hms((end - start) / SAMPLING_RATE)
        hover_texts.append(
            f"Inizio: {hms(start_s)}<br>Fine: {hms(end_s)}<br>Durata: {duration_str}"
            f"<br>Testo: {label}<br>Attendibilità: {height*100:.2f}%<br>Emozione: {emotion}"
        )

    fig = go.Figure()

    # Shade the band from the threshold up to 100% across the observed span.
    # Guarded so empty input does not fail. min/max keeps the band correct
    # even when the chunks are not sorted by time.
    if chunks:
        fig.add_shape(
            type="rect",
            x0=base_time + timedelta(seconds=min(chunk_starts)),
            x1=base_time + timedelta(seconds=max(chunk_ends)),
            y0=confidence_threshold,
            y1=100,
            fillcolor="rgba(188,223,241,0.8)",
            opacity=0.8,
            layer="below",
            line_width=0,
        )

    fig.add_hline(y=confidence_threshold, line_dash="dash", line_color="black", line_width=1)

    # Continuous confidence curve through all chunk midpoints.
    fig.add_trace(
        go.Scatter(
            x=mid_times,
            y=heights,
            mode='lines',
            name='Disregolazione',
            line=dict(
                color='orange',
                width=2,
                shape='spline',
                smoothing=1.0,
            ),
            text=hover_texts,
            hoverinfo='text',
            showlegend=False,
        )
    )

    # Group the above-threshold points by emotion, keeping first-seen order.
    emotion_data = {}
    for when, value, emotion, hover in zip(mid_times, heights, emotions, hover_texts):
        if value >= confidence_threshold:
            group = emotion_data.setdefault(
                emotion, {'times': [], 'heights': [], 'hover_texts': []}
            )
            group['times'].append(when)
            group['heights'].append(value)
            group['hover_texts'].append(hover)

    extra_emotions = [e for e in emotion_data if e not in emotion_order]

    for emotion in emotion_order + extra_emotions:
        marker = dict(
            size=15,
            color=COLOR_MAP.get(emotion, '#000000'),
            symbol='circle',
        )
        data = emotion_data.get(emotion)
        if data is None:
            # Placeholder trace: keeps a legend entry for every ordered
            # emotion even when it has no above-threshold points.
            trace = go.Scatter(
                x=[None],
                y=[None],
                mode='markers',
                name=emotion.capitalize(),
                marker=marker,
                showlegend=True,
            )
        else:
            trace = go.Scatter(
                x=data['times'],
                y=data['heights'],
                mode='markers',
                name=emotion.capitalize(),
                marker=marker,
                text=data['hover_texts'],
                hoverinfo='text',
                showlegend=True,
            )
        fig.add_trace(trace)

    fig.update_layout(
        title='Distribuzione della disregolazione',
        xaxis_title='Tempo',
        yaxis_title='Attendibilità',
        xaxis=dict(
            type='date',
            tickformat='%H:%M:%S',
            showline=True,
            zeroline=False,
            side='bottom',
            showgrid=False,
        ),
        yaxis=dict(
            range=[0, 100],
            tickvals=[0, 20, 40, 60, 80, 100],
            ticktext=['0%', '20%', '40%', '60%', '80%', '100%'],
            tickmode='array',
            showgrid=False,
        ),
        legend_title=None,
        legend=dict(
            yanchor="top"
        ),
        hoverlabel=dict(
            font_size=12,
            font_family="Arial"
        ),
        paper_bgcolor='white',
        plot_bgcolor='white',
    )

    # Clear any hovertemplate so the plain `text` hover strings are shown.
    fig.update_traces(hovertemplate=None)

    return fig