|
// Maps the metric keys used in the plot JSON files to human-readable labels.
// The labels are shown in the metric dropdown and used as the y-axis title;
// metrics without an entry here are not offered in the dropdown.
const TASK_ID_TO_NAME = Object.fromEntries([
    ['agg_score', 'Aggregate Score'],
    ['commonsense_qa/acc_norm', 'Commonsense QA'],
    ['hellaswag/acc_norm', 'HellaSwag'],
    ['openbookqa/acc_norm', 'OpenBook QA'],
    ['piqa/acc_norm', 'PIQA'],
    ['winogrande/acc_norm', 'WinoGrande'],
    ['arc/acc_norm', 'ARC'],
    ['mmlu/acc_norm', 'MMLU'],
]);
|
|
|
|
|
/**
 * Base Plotly layout for every plot; per-plot settings are merged over it
 * when a chart is rendered.
 */
const DEFAULT_LAYOUT = (() => {
    // Single font stack for every text element. The system-font keyword is
    // `-apple-system` (leading hyphen); spelled without it, the name matches
    // no font and browsers silently fall through to Arial.
    const FONT_FAMILY = "-apple-system, Arial, sans-serif";
    // Fresh object per call so no two layout entries share a font object.
    const font = (size) => ({size, family: FONT_FAMILY});

    return {
        title: {
            text: 'Plot Title',
            font: font(19)
        },
        xaxis: {
            title: {
                text: 'Training tokens (billions)',
                font: font(15)
            },
            tickfont: font(14),
            showgrid: false,
            mirror: true,
            ticks: 'outside',
            showline: true
        },
        yaxis: {
            title: {
                text: "Agg Score",
                font: font(15),
                standoff: 10
            },
            showgrid: false,
            mirror: true,
            ticks: 'outside',
            showline: true,
            tickfont: font(14)
        },
        legend: {
            orientation: 'v',
            // Anchored by its bottom-right corner at (x=1, y=0).
            xanchor: 'right',
            yanchor: 'bottom',
            x: 1,
            y: 0,
            font: font(14),
            // Fully transparent background.
            bgcolor: 'rgba(0,0,0,0)'
        },
        margin: {
            t: 30,
            b: 50
        },
        height: 400
    };
})();
|
|
|
|
|
|
|
/**
 * Renders one Plotly chart into every element whose id starts with "plot-".
 *
 * For an element with id "plot-<name>", `data/plots/<name>.json` is fetched
 * and read in one of two shapes:
 *   - `{data: {<metric>: {<series>: {x, y, label}}}}`: one line trace is
 *     built per series and smoothed with `rollingWindow` using the slider value;
 *   - `{traces: {<metric>: [<trace>, ...]}}`: the traces are passed to Plotly
 *     as-is (the slider value is not applied to them).
 * Optional keys `defaultMetric`, `defaultWindowSize` and `createSlider`
 * configure the controls; `layout` is merged over DEFAULT_LAYOUT.
 *
 * Each element is loaded independently: a failed fetch or render is logged
 * with console.error and does not stop the other plots.
 */
const init_plot = function() {
    // Debounce delay for slider input, in milliseconds.
    const SLIDER_DEBOUNCE_MS = 500;
    // Resize events are ignored while the viewport is narrower than this (px).
    const MIN_RESIZE_WIDTH = 768;

    /**
     * Builds the Plotly traces to draw for one metric.
     * @param {object} data - parsed plot JSON (either shape described above).
     * @param {string} metric - metric key currently selected in the dropdown.
     * @param {number} windowSize - rolling-window size passed to rollingWindow.
     * @returns {object[]} traces; empty when the metric has no data.
     */
    const buildTraces = (data, metric, windowSize) => {
        if ("traces" in data) {
            return data.traces[metric] ?? [];
        }
        const metricData = data.data?.[metric] ?? {};
        return Object.values(metricData).map((series) => {
            const y = rollingWindow(series.y, windowSize);
            return {
                // NOTE(review): each smoothed value is plotted at the x of the
                // first sample in its window, since x is truncated from the end.
                x: series.x.slice(0, y.length),
                y,
                type: 'scatter',
                mode: 'lines',
                line: {
                    width: 2.5
                },
                name: series.label
            };
        });
    };

    /**
     * Returns the x-axis range [min*0.95, max*1.05] over all trace x values,
     * or null when there are no points (Math.min/Math.max of nothing would
     * otherwise give an [Infinity, -Infinity] range).
     */
    const computeXRange = (traces) => {
        const xs = traces.flatMap((trace) => trace.x ?? []);
        if (xs.length === 0) {
            return null;
        }
        return [Math.min(...xs) * 0.95, Math.max(...xs) * 1.05];
    };

    /** Fetches the data for one plot element, builds its controls and draws it. */
    const setupPlot = async (plotElement) => {
        // The selector guarantees the id starts with "plot-".
        const plotName = plotElement.id.slice('plot-'.length);
        const url = `data/plots/${plotName}.json`;
        const response = await fetch(url);
        if (!response.ok) {
            throw new Error(`Failed to load ${url}: HTTP ${response.status}`);
        }
        const data = await response.json();

        const {dropdown, slider, plot} = createPlottingElements(
            plotElement,
            data.data ?? data.traces,
            data.defaultMetric ?? "agg_score",
            data.defaultWindowSize ?? 0,
            data.createSlider ?? 1
        );
        plot.id = `graph-${plotName}`;

        const updatePlot = () => {
            const metric = dropdown.value;
            const windowSize = Number.parseInt(slider?.value ?? 0, 10);
            const traces = buildTraces(data, metric, windowSize);
            const xRange = computeXRange(traces);
            const overrides = {
                width: plot.parentElement.offsetWidth,
                yaxis: {title: {text: TASK_ID_TO_NAME[metric]}},
                // Leave the x range to Plotly when there is nothing to plot.
                ...(xRange ? {xaxis: {range: xRange}} : {})
            };
            const layout = _.merge({}, DEFAULT_LAYOUT, overrides, data.layout);
            Plotly.newPlot(plot, traces, layout);
        };

        dropdown.addEventListener('change', updatePlot);

        if (slider) {
            let timeoutId;
            // Debounced so dragging the slider does not redraw on every tick.
            slider.addEventListener('input', () => {
                clearTimeout(timeoutId);
                timeoutId = setTimeout(updatePlot, SLIDER_DEBOUNCE_MS);
            });
        }

        updatePlot();

        // Registered exactly once per plot, outside updatePlot, so repeated
        // redraws do not accumulate duplicate resize listeners.
        window.addEventListener('resize', () => {
            if (window.innerWidth < MIN_RESIZE_WIDTH) {
                return;
            }
            Plotly.relayout(plot, {width: plot.parentElement.offsetWidth});
        });
    };

    document.querySelectorAll('[id^="plot-"]').forEach((plotElement) => {
        setupPlot(plotElement).catch((error) => {
            console.error(`Failed to initialise plot "${plotElement.id}":`, error);
        });
    });
};
|
// Build all plots once the initial HTML document has been parsed.
document.addEventListener('DOMContentLoaded', () => init_plot());
|
|
|
|
|
/**
 * Picks the maximum value of the rolling-window slider.
 *
 * The sample count is the length of `x` in the first series of the first
 * metric of `data` (series may be keyed by name or be array elements).
 * Fewer than 20 samples gives a maximum of 10, otherwise 30. Missing metrics
 * or series count as 0 samples instead of throwing.
 *
 * @param {object} data - metric-keyed plot data.
 * @returns {number} 10 or 30.
 */
const getSliderMax = (data) => {
    const firstMetricData = Object.values(data ?? {})[0] ?? {};
    const firstSeries = Object.values(firstMetricData)[0];
    const totalSamples = firstSeries?.x?.length ?? 0;
    return totalSamples < 20 ? 10 : 30;
};
|
|
|
/**
 * Appends the DOM for one plot to `plotElement`: a `<figure class="plotly">`
 * for the chart, followed by a `<div class="plotly_controls">` that holds a
 * metric `<select>` and, optionally, a rolling-window range slider with a
 * live value readout.
 *
 * @param {HTMLElement} plotElement - container the new elements are appended to.
 * @param {object} data - metric-keyed data; only keys that have a label in
 *     TASK_ID_TO_NAME become dropdown options, in `data`'s key order.
 * @param {string} defaultMetric - metric selected initially; when it is not
 *     one of the options, the first option stays selected.
 * @param {number} defaultWindowSize - initial slider value.
 * @param {*} createSlider - falsy to omit the slider.
 * @returns {{dropdown: HTMLSelectElement, slider: (HTMLInputElement|undefined), plot: HTMLElement}}
 */
const createPlottingElements = (plotElement, data, defaultMetric, defaultWindowSize, createSlider) => {
    const plot = document.createElement('figure');
    plot.classList.add('plotly');
    const controls = document.createElement('div');
    controls.classList.add('plotly_controls');
    plotElement.appendChild(plot);
    plotElement.appendChild(controls);

    // Own-property check rather than `in`: `in` also matches inherited keys
    // such as "constructor" or "toString", which would add bogus options.
    const metricOptions = Object.keys(data).filter(
        (metric) => Object.prototype.hasOwnProperty.call(TASK_ID_TO_NAME, metric)
    );

    const dropdownLabel = document.createElement('label');
    dropdownLabel.textContent = 'Metric:';
    const dropdown = document.createElement('select');
    for (const metric of metricOptions) {
        const option = document.createElement('option');
        option.value = metric;
        option.textContent = TASK_ID_TO_NAME[metric];
        dropdown.appendChild(option);
    }
    // Otherwise the select keeps its first option selected; assigning an
    // unknown value would leave it with no selection at all.
    if (metricOptions.includes(defaultMetric)) {
        dropdown.value = defaultMetric;
    }

    const dropdownContainer = document.createElement('div');
    dropdownContainer.classList.add('plotly_input_container');
    dropdownContainer.appendChild(dropdownLabel);
    dropdownContainer.appendChild(dropdown);
    controls.appendChild(dropdownContainer);

    if (!createSlider) {
        return {dropdown, slider: undefined, plot};
    }

    const sliderLabel = document.createElement('label');
    sliderLabel.textContent = 'Rolling window:';
    const slider = document.createElement('input');
    slider.type = 'range';
    slider.min = 0;
    slider.max = getSliderMax(data);
    slider.value = defaultWindowSize ?? 0;

    // Live readout of the current slider value.
    const sliderValue = document.createElement('span');
    sliderValue.textContent = slider.value;
    slider.addEventListener('input', () => {
        sliderValue.textContent = slider.value;
    });

    const sliderInputContainer = document.createElement('div');
    sliderInputContainer.classList.add('plotly_slider');
    sliderInputContainer.appendChild(slider);
    sliderInputContainer.appendChild(sliderValue);

    const sliderContainer = document.createElement('div');
    sliderContainer.classList.add('plotly_input_container');
    sliderContainer.appendChild(sliderLabel);
    sliderContainer.appendChild(sliderInputContainer);
    controls.appendChild(sliderContainer);

    return {dropdown, slider, plot};
};
|
|
|
/**
 * Simple moving average of `data` over full windows of `windowSize` points.
 *
 * Element j of the result is the mean of data[j .. j + windowSize - 1], so the
 * result has `data.length - windowSize + 1` elements (none when the window is
 * longer than the data). Every point, including the last, falls in some window.
 *
 * @param {number[]} data - values to smooth.
 * @param {number} windowSize - points per window; 0, negative, NaN or missing
 *     disables smoothing.
 * @returns {number[]} the averaged series, or `data` itself (same array) when
 *     smoothing is disabled.
 */
const rollingWindow = function(data, windowSize) {
    if (!(windowSize > 0)) {
        return data;
    }
    const windowCount = Math.max(0, data.length - windowSize + 1);
    return Array.from({length: windowCount}, (_, start) => {
        const chunk = data.slice(start, start + windowSize);
        return chunk.reduce((acc, value) => acc + value, 0) / chunk.length;
    });
};
|
|