Spaces:
Runtime error
Runtime error
from bokeh.events import Tap | |
from bokeh.io import curdoc | |
from bokeh.layouts import column | |
from bokeh.models import Div, TextInput, RadioButtonGroup, TextAreaInput, Span, Button, Panel, Tabs | |
from bokeh.models.tools import CrosshairTool | |
from demo_utils import ( | |
get_data, | |
prompt_boolq, | |
pvp_colors, | |
ctl_colors, | |
clf_colors, | |
reduct, | |
task_best_pattern, | |
plot_polygons_bokeh, | |
advantage_text, | |
data_difference, | |
calculate_overlap, | |
circ_easing, | |
average_advantage_text, | |
plot_three_polygons_bokeh, | |
tasks, | |
metric_tap, | |
neutral_tasks, pattern_graph, | |
) | |
from text import text1, text2, text3, text4, initial_passage, initial_question, text5 | |
########################################################################################################################
# Basic dimensions
########################################################################################################################
# NOTE(review): of these, only `text_width` is referenced elsewhere in this file (as a widget width and in
# f"{text_width}px" CSS); the others appear unused here — confirm before removing.
plot_width, plot_height = 1200, 400
sidebar_width = 400
in_text_plot_height = 300
text_width = 800
widget_size = 400
########################################################################################################################
# Patternification widget
########################################################################################################################
# Inputs for the example passage ("篇章"), question ("问题") and prompt template choice ("模板 N" = template N).
passage = TextAreaInput(title="篇章", rows=3, value=initial_passage, max_width=text_width, align="center")
question = TextInput(title="问题", value=initial_question, max_width=text_width, align="center")
radio_button_group = RadioButtonGroup(
    labels=["模板 1", "模板 2", "模板 3"],
    active=0,
    max_width=text_width,
    align="center",
)

# Inline CSS shared by the grey read-out boxes in this file (the prompt preview and the advantage box).
box_style = {
    "display": "block",
    "margin": "0 auto",
    "width": f"{text_width}px",
    "text-align": "center",
    "white-space": "pre-wrap",
    "background": "#f4f4f4",
    "border": "1px solid #ddd",
    "color": "#666",
    "page-break-inside": "avoid",
    "font-size": "15px",
    "line-height": "1.6",
    "max-width": "100%",
    "overflow": "hidden",
    "min-height": "30px",
    "word-wrap": "break-word",
}

# Preview of the prompt built from the initial widget state; refreshed by the widget callbacks.
prompt_box = Div(
    text=prompt_boolq(passage.value, question.value, radio_button_group.active),
    width=text_width,
    style=box_style,
    sizing_mode="scale_width",
    align="center",
)
def update_prompt(attrname, old, new):
    """Re-render ``prompt_box`` from the current passage, question and template selection.

    Used as a Bokeh ``on_change`` callback; the ``(attrname, old, new)`` arguments are ignored because the
    prompt is rebuilt from the current state of all three widgets.
    """
    template_index = radio_button_group.active
    prompt_box.text = prompt_boolq(passage.value, question.value, template_index)


passage.on_change("value", update_prompt)
question.on_change("value", update_prompt)
radio_button_group.on_change("active", update_prompt)
patternification = column(
    passage,
    question,
    radio_button_group,
    prompt_box,
    sizing_mode="scale_width",
    align="center",
)
########################################################################################################################
# Advantage diagram
########################################################################################################################
# Parallel per-tab lists: entry i of each list belongs to tab i of `advantage_all_figures`. A task named "MNLI"
# contributes two consecutive tabs (default x axis, then log-scale x axis).
advantage_plots_per_task = []
overlapping_range_per_task = []
training_points_per_task = []
clf_results_per_task = []
pvp_results_per_task = []
advantage_tabs = []
advantage_box = Div(
    text="Click within the comparison region to compute the data advantage for a performance level",
    width=text_width,
    style=box_style,
    sizing_mode="scale_width",
)
advantage_box.align = "center"


def _add_advantage_tab(task, title, training_points, classifier_performances, pattern_performances, **plot_kwargs):
    """Build one advantage plot for *task*, append it as a tab titled *title*, and attach its Tap handler.

    Appends one entry to each of the parallel ``*_per_task`` lists and to ``advantage_tabs``. A fresh list copy
    of *training_points* is stored for the tab, so the caller should pass an already-materialized sequence.
    Extra keyword arguments (e.g. ``x_log_scale=True``) are forwarded to ``plot_polygons_bokeh``.
    """
    tab_index = len(advantage_plots_per_task)
    training_points_per_task.append(list(training_points))
    clf_results_per_task.append(reduct(classifier_performances, "accmax"))
    pvp_results_per_task.append(reduct(pattern_performances, "accmax", task_best_pattern[task], "normal"))
    plot = plot_polygons_bokeh(
        task, training_points_per_task[tab_index], clf_results_per_task[tab_index],
        pvp_results_per_task[tab_index], clf_colors, pvp_colors, **plot_kwargs
    )
    advantage_plots_per_task.append(plot)
    plot.align = "center"
    plot.add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
    overlapping_range_per_task.append(
        calculate_overlap(clf_results_per_task[tab_index], pvp_results_per_task[tab_index])
    )
    advantage_tabs.append(Panel(child=plot, title=title))

    def _on_tap(event):
        # Use the data of the tab this plot belongs to (captured here) instead of looking up whichever tab is
        # active on the Tabs widget when the event arrives.
        metric_tap(
            event,
            overlapping_range_per_task[tab_index],
            training_points_per_task[tab_index],
            clf_results_per_task[tab_index],
            pvp_results_per_task[tab_index],
            advantage_box,
            advantage_plots_per_task[tab_index],
        )

    plot.on_event(Tap, _on_tap)


for task in tasks:
    training_points, classifier_performances, pattern_performances = get_data(task)
    # Materialize once: if get_data hands back a one-shot iterator, a second list() call would yield nothing
    # and the MNLI log-scale tab would be plotted without data.
    training_points = list(training_points)
    _add_advantage_tab(task, task, training_points, classifier_performances, pattern_performances)
    if task == "MNLI":
        _add_advantage_tab(
            task, "MNLI (log scale)", training_points, classifier_performances, pattern_performances,
            x_log_scale=True,
        )

advantage_all_figures = Tabs(tabs=advantage_tabs)
advantage_all_figures.align = "center"
def on_integrate_click():
    """Sweep the metric line across the active tab's comparison region and show the mean data advantage.

    For the currently active tab, the interval from ``overlapping_range_per_task[tab][0]`` to ``[1]`` is
    sampled at ``frames - 1`` strictly interior points. At each point ``data_difference`` gives the advantage;
    the dashed horizontal ``Span`` is moved there and coloured ``clf_colors[0]`` when the advantage is negative
    and ``pvp_colors[0]`` otherwise. The mean of the sampled advantages is rendered into ``advantage_box`` via
    ``average_advantage_text``.
    """
    frames = 200
    # Snapshot the active tab once so every lookup below refers to the same plot and data.
    tab = advantage_all_figures.active
    plot = advantage_plots_per_task[tab]
    overlap = overlapping_range_per_task[tab]
    points = training_points_per_task[tab]
    clf_results = clf_results_per_task[tab]
    pvp_results = pvp_results_per_task[tab]
    start, end = overlap[0], overlap[1]

    # Reuse a Span left as the last renderer (e.g. by an earlier click), otherwise append a new one.
    # Its colour is assigned from the advantage's sign on every frame of the sweep, so none is chosen here.
    renderers = plot.renderers
    if renderers and isinstance(renderers[-1], Span):
        metric_line = renderers[-1]
        metric_line.location = start
    else:
        metric_line = Span(
            location=start,
            line_alpha=0.7,
            dimension="width",
            line_dash="dashed",
            line_width=1,
        )
        renderers.extend([metric_line])

    # NOTE(review): the whole sweep runs inside a single server callback, so the intermediate positions are
    # probably not drawn as a visible animation — confirm if the animation effect matters.
    average_advantage = 0
    for i in range(1, frames):
        metric_value = start + (end - start) * (i / frames)
        advantage_value = data_difference(metric_value, overlap, points, clf_results, pvp_results)
        # Incremental mean over the i samples taken so far.
        average_advantage = ((i - 1) * average_advantage + advantage_value) / i
        metric_line.location = metric_value
        metric_line.line_color = clf_colors[0] if advantage_value < 0 else pvp_colors[0]
    advantage_box.text = average_advantage_text(average_advantage)
integrate = Button(label="Integrate over the whole region!", width=175, max_width=175, align="center")
integrate.on_click(on_integrate_click)


def on_tab_change(attr, old, new):
    """Reset ``advantage_box`` to its placeholder prompt whenever the selected advantage tab changes.

    Bokeh ``on_change`` callback for ``advantage_all_figures.active``; its arguments are unused.
    """
    placeholder = "Click within the comparison region to compute the data advantage for a performance level"
    advantage_box.text = placeholder


advantage_all_figures.on_change("active", on_tab_change)
advantage_column = column(advantage_all_figures, advantage_box, integrate, sizing_mode="scale_width")
########################################################################################################################
# Null verbalizer diagram
########################################################################################################################
def _null_verbalizer_panel(task, title, training_points, clf_results, pvp_results, ctl_results, **plot_kwargs):
    """Build the three-polygon plot for *task* and wrap it in a tab titled *title*.

    *clf_results*, *pvp_results* and *ctl_results* are the classifier, "normal"-pattern and "neutral"-pattern
    reductions computed by the caller. Extra keyword arguments (e.g. ``x_log_scale=True``) are forwarded to
    ``plot_three_polygons_bokeh``.
    """
    plot = plot_three_polygons_bokeh(
        task, training_points, clf_results, pvp_results, ctl_results, clf_colors, pvp_colors, ctl_colors,
        **plot_kwargs
    )
    plot.align = "center"
    plot.add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
    return Panel(child=plot, title=title)


null_tabs = []
for task in neutral_tasks:
    training_points, classifier_performances, pattern_performances = get_data(task)
    training_points = list(training_points)
    clf_results = reduct(classifier_performances, "accmax")
    pvp_results = reduct(pattern_performances, "accmax", task_best_pattern[task], "normal")
    ctl_results = reduct(pattern_performances, "accmax", task_best_pattern[task], "neutral")
    null_tabs.append(_null_verbalizer_panel(task, task, training_points, clf_results, pvp_results, ctl_results))
    if task == "MNLI":
        # Extra MNLI tab with a logarithmic x axis, built from the same reduced results.
        null_tabs.append(_null_verbalizer_panel(
            task, "MNLI (log scale)", training_points, clf_results, pvp_results, ctl_results, x_log_scale=True
        ))
null_all_figures = Tabs(tabs=null_tabs)
null_all_figures.align = "center"
########################################################################################################################
# Patterns diagram
########################################################################################################################
# One tab per task, each holding the figure returned by `pattern_graph(task)` with a horizontal crosshair.
pattern_tabs = []
for task in tasks:
    pattern_plot = pattern_graph(task)
    pattern_plot.align = "center"
    pattern_plot.add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
    pattern_tabs.append(Panel(child=pattern_plot, title=task))
# Build the Tabs widget only once all panels exist.
pattern_all_figures = Tabs(tabs=pattern_tabs)
pattern_all_figures.align = "center"
########################################################################################################################
# Add write-up text
########################################################################################################################
# Inline CSS shared by every prose section: a centred, fixed-width column with 18px text.
main_text_style = {
    "min-height": "100px",
    "overflow": "hidden",
    "display": "block",
    "margin": "auto",
    "width": f"{text_width}px",
    "font-size": "18px",
}
textbox1, textbox2, textbox3, textbox4, textbox5 = (
    Div(text=section_html, style=main_text_style, align="center")
    for section_html in (text1, text2, text3, text4, text5)
)
########################################################################################################################
# Set up layouts and add to document
########################################################################################################################
# Prose sections interleaved with the interactive widgets, top to bottom.
main_body = column(
    textbox1, patternification,
    textbox2, advantage_column,
    textbox3, null_all_figures,
    textbox4, pattern_all_figures,
    textbox5,
    sizing_mode="scale_width",
    align="center",
)
curdoc().add_root(main_body)
curdoc().title = "一条提示抵得上多少样本数据?"