|
import Plotly from 'plotly.js-basic-dist-min'; |
|
import Papa from 'papaparse'; |
|
import _ from 'lodash'; |
|
import { getColor } from './colors.mjs'; |
|
|
|
/**
 * Language selector entries: display name -> language code.
 * The code is the key into `taskLists` and the directory name used under
 * `data/nanotron_tasks/`. Key order is the order of the selector's options.
 */
const languageMap = {
  Arabic: "ar",
  Turkish: "tr",
  Swahili: "sw",
  Russian: "ru",
  Telugu: "te",
  Thai: "th",
  Chinese: "zh",
  French: "fr",
  Hindi: "hi",
};
|
|
|
/**
 * Raw run-name substring -> display label.
 * processRunName tests these keys in insertion order and uses the first one
 * contained in the raw run name, so the order of entries is significant.
 */
const runNameMap = {
  orion: 'Dataset-A',
  helios: 'Dataset-B',
  lynx: 'Dataset-C',
  aquila: 'Dataset-D',
  commoncrawl: 'CommonCrawl',
  baseline: 'Baseline',
};
|
|
|
/**
 * Task names available per language code (the codes used in `languageMap`).
 * Each name is also the file stem of `<task>_data.csv` / `<task>_stats.csv`
 * under `data/nanotron_tasks/<code>/`, so entries must match file names exactly.
 *
 * NOTE(review): "alfgahafa_mlqa_ara_cf" does not follow the "alghafa_" prefix of
 * its Arabic neighbours; confirm against the data directory before renaming it.
 */
const taskLists = {
  ar: ["acva_ara:_average", "alfgahafa_mlqa_ara_cf", "alghafa_arc_ara_cf:easy", "alghafa_facts_ara_cf", "alghafa_meta_dialects_ara_cf", "alghafa_mmlu_ara_cf:_average", "alghafa_openbookqa_ara_cf", "alghafa_piqa_ara_cf", "alghafa_race_ara_cf", "alghafa_rating_sentiment_ara_cf", "alghafa_rating_sentiment_no_neutral_ara_cf", "alghafa_sciqa_ara_cf", "alghafa_sentiment_ara_cf", "arcd_ara", "belebele_arb_Arab_cf", "boolq_ara", "exams_ara_cf:_average", "mkqa_ara:_average", "mlmm_arc_ara_cf:challenge", "mlmm_hellaswag_ara_cf", "mlmm_mmlu_ara_cf:_average", "mlmm_truthfulqa_ara_cf:mc1", "mlmm_truthfulqa_ara_cf:mc2", "mlqa_ara", "mmlu_ara_cf:_average", "soqal_ara_cf", "toxigen_ara_cf", "tydiqa_ara", "xcodah_ara_cf", "xcopa_ara_cf", "xcsqa_ara_cf", "xnli2.0_ara_cf", "xnli_ara_cf", "xquad_ara", "xstory_cloze_ara_cf"],
  fr: ["belebele_fra_Latn_cf", "community_boolq_fra_cf", "exams_fra_cf:_average", "fquadv2_fra", "frenchbench_arc_fra_cf:challenge", "frenchbench_hellaswag_fra_cf", "meta_mmlu_fra_cf:_average", "mintaka_fra", "mkqa_fra:_average", "mlmm_arc_fra_cf:challenge", "mlmm_hellaswag_fra_cf", "mlmm_mmlu_fra_cf:_average", "mlmm_truthfulqa_fra_cf:mc1", "mlmm_truthfulqa_fra_cf:mc2", "pawsx_fra_cf", "xcodah_fra_cf", "xcsqa_fra_cf", "xnli2.0_fra_cf", "xwinograd_fra_cf"],
  hi: ["belebele_hin_Deva_cf", "community_arc_hin_cf:challenge", "community_arc_hin_cf:easy", "community_boolq_hin", "community_hellaswag_hin_cf", "indicnxnli_hin_cf", "indicqa_hin", "indicxcopa_hin_cf", "meta_mmlu_hin_cf:_average", "mintaka_hin", "mlmm_arc_hin_cf:challenge", "mlmm_hellaswag_hin_cf", "mlmm_mmlu_hin_cf:_average", "mlmm_truthfulqa_hin_cf:mc1", "mlmm_truthfulqa_hin_cf:mc2", "mlqa_hin", "xcodah_hin_cf", "xcsqa_hin_cf", "xnli2.0_hin_cf", "xnli_hin_cf", "xquad_hin", "xstory_cloze_hin_cf"],
  ru: ["belebele_rus_Cyrl_cf", "chegeka_rus", "mathlogic_qa_rus_cf", "mera_openbookqa_rus_cf", "mera_worldtree_rus_cf", "mkqa_rus:_average", "mlmm_arc_rus_cf:challenge", "mlmm_hellaswag_rus_cf", "mlmm_mmlu_rus_cf:_average", "mlmm_truthfulqa_rus_cf:mc1", "mlmm_truthfulqa_rus_cf:mc2", "parus_rus_cf", "rcb_rus_cf", "rummlu_rus_cf:_average", "sber_squad_rus", "tydiqa_rus", "xcodah_rus_cf", "xcsqa_rus_cf", "xnli2.0_rus_cf", "xquad_rus", "xstory_cloze_rus_cf", "xwinograd_rus_cf"],
  sw: ["afric_mmlu_swa_cf:_average", "afric_xnli_swa_cf", "belebele_swh_Latn_cf", "community_arc_swa_cf:challenge", "community_arc_swa_cf:easy", "community_mmlu_swa_cf", "kenswquad_swa", "m3exams_swa_cf", "openai_mmlu_swa_cf:_average", "tydiqa_swa", "xcodah_swa_cf", "xcopa_swa_cf", "xcsqa_swa_cf", "xnli2.0_swa_cf", "xnli_swa_cf", "xstory_cloze_swa_cf"],
  te: ["belebele_tel_Telu_cf", "community_hellaswag_tel_cf", "indicnxnli_tel_cf", "indicqa_tel", "indicxcopa_tel_cf", "mlmm_arc_tel_cf:challenge", "mlmm_hellaswag_tel_cf", "mlmm_mmlu_tel_cf:_average", "mlmm_truthfulqa_tel_cf:mc1", "mlmm_truthfulqa_tel_cf:mc2", "tydiqa_tel", "xstory_cloze_tel_cf"],
  th: ["belebele_tha_Thai_cf", "community_hellaswag_tha_cf", "m3exams_tha_cf", "meta_mmlu_tha_cf:_average", "mkqa_tha:_average", "thai_exams_tha_cf:_average", "thai_exams_tha_cf:tgat", "thaiqa_tha", "wsci_tha_cf", "xcopa_tha_cf", "xnli2.0_tha_cf", "xnli_tha_cf", "xquad_tha"],
  tr: ["belebele_tur_Latn_cf", "community_arc_tur_cf:easy", "community_hellaswag_tur_cf", "community_mmlu_tur_cf:_average", "community_truthfulqa_tur_cf:mc1", "community_truthfulqa_tur_cf:mc2", "community_xwinograd_tur_cf", "exams_tur_cf:_average", "mkqa_tur:_average", "tquadv2_tur", "xcopa_tur_cf", "xnli2.0_tur_cf", "xnli_tur_cf", "xquad_tur"],
  zh: ["agieval_zho_cf:_average", "belebele_zho_Hans_cf", "c3_zho_cf", "ceval_zho_cf:_average", "chinese_squad_zho", "cmath_zho_cf", "cmmlu_zho_cf:_average", "cmnli_zho_cf", "cmrc2018_zho", "m3exams_zho_cf", "mkqa_zho:_average", "mlmm_arc_zho_cf:challenge", "mlmm_hellaswag_zho_cf", "mlmm_mmlu_zho_cf:_average", "mlmm_truthfulqa_zho_cf:mc1", "mlmm_truthfulqa_zho_cf:mc2", "ocnli_zho_cf", "pawsx_zho_cf", "xcodah_zho_cf", "xcopa_zho_cf", "xcsqa_zho_cf", "xnli2.0_zho_cf", "xnli_zho_cf", "xquad_zho", "xstory_cloze_zho_cf", "xwinograd_zho_cf"],
};
|
|
|
/** Shared trace styling (line width, trace type and drawing mode) used when building traces. */
const LINE_SETTINGS = { width: 2.5, type: 'scatter', mode: 'lines+markers' };
|
|
|
/**
 * Base Plotly layout shared by every task plot. plotData deep-merges the
 * per-plot overrides (title text, axis ticks and ranges, width) on top of it.
 */
const DEFAULT_LAYOUT = {
  font: {
    // "-apple-system" (leading hyphen) is the CSS name of Apple's system UI
    // font; the unhyphenated "apple-system" matches no font at all, so the
    // stack silently fell through to Arial everywhere.
    family: '-apple-system, Arial, sans-serif',
  },
  title: {
    font: { size: 15 },
  },
  xaxis: {
    title: {
      text: 'Training Tokens (billions)',
      font: { size: 14 },
    },
    tickfont: { size: 12 },
    showgrid: false,
    mirror: true,
    ticks: 'outside',
    showline: true,
  },
  yaxis: {
    title: {
      font: { size: 14 },
      standoff: 10,
    },
    showgrid: false,
    mirror: true,
    ticks: 'outside',
    showline: true,
    tickfont: { size: 12 },
  },
  height: 300,
  autosize: true,
  // Horizontal legend anchored by its bottom-right corner at (x=1, y=0).
  legend: {
    orientation: 'h',
    yanchor: 'bottom',
    y: 0,
    xanchor: 'right',
    x: 1,
    traceorder: 'normal',
    font: { size: 12 },
    tracegroupgap: 0,
    bgcolor: 'rgba(255, 255, 255, 0.8)',
  },
  margin: {
    t: 25,
    b: 60,
    l: 60,
    r: 40,
  },
};
|
|
|
/**
 * Entry point: build a plot applet inside every `.task-signal-plot` element
 * currently in the document.
 */
export function initPlotApplets() {
  for (const appletRoot of document.querySelectorAll('.task-signal-plot')) {
    initPlotApplet(appletRoot);
  }
}
|
|
|
/**
 * Parse the comma-separated `data-task-metrics` attribute into the list of
 * statistic keys to display. Entries are trimmed and empty entries dropped,
 * so "monotonicity, snr" and "monotonicity,snr," both yield
 * ['monotonicity', 'snr']. Falls back to all four statistics when the
 * attribute is missing or empty.
 *
 * @param {string|undefined} raw - Attribute value as read from `dataset`.
 * @returns {string[]} Statistic keys in attribute order.
 */
function parseTaskMetricList(raw) {
  return (raw || 'monotonicity,snr,ordering,randomness')
    .split(',')
    .map((name) => name.trim())
    .filter((name) => name !== '');
}

/**
 * Build one plot applet inside `container`: the selector controls, a Plotly
 * plot area and a statistics panel, then load the initial selection.
 *
 * Configuration comes from the container's data attributes:
 * `data-language` (default 'Arabic'), `data-task`, `data-metric`,
 * `data-group-seeds` ('true' to average seeds), `data-show-controls`
 * ('true' to show the selectors) and `data-task-metrics`
 * (comma-separated statistic keys).
 *
 * @param {HTMLElement} container - Applet root element.
 */
function initPlotApplet(container) {
  const defaultLanguage = container.dataset.language || 'Arabic';
  const defaultTask = container.dataset.task || '';
  const defaultMetric = container.dataset.metric || '';
  const groupSeeds = container.dataset.groupSeeds === 'true';
  const showControls = container.dataset.showControls === 'true';
  // Statistic keys are later matched exactly (taskMetrics.includes('snr')),
  // so stray whitespace from the attribute must not survive parsing.
  const taskMetrics = parseTaskMetricList(container.dataset.taskMetrics);

  const controls = createControls(container, defaultLanguage, defaultTask, defaultMetric, taskMetrics);
  if (!showControls) {
    controls.style.display = 'none';
  }
  container.appendChild(controls);

  const plotContainer = document.createElement('div');
  plotContainer.className = 'plot-container';
  container.appendChild(plotContainer);

  const statsContainer = document.createElement('div');
  statsContainer.className = 'stats-container';
  container.appendChild(statsContainer);

  // Start with an empty plot so the div is a Plotly graph before data arrives.
  Plotly.newPlot(plotContainer, []);

  // Keep the plot as wide as its applet whenever the window is resized.
  const resizePlot = () => {
    const width = container.offsetWidth;
    Plotly.relayout(plotContainer, { width: width });
  };
  window.addEventListener('resize', resizePlot);
  resizePlot();

  // Fill the task and metric selectors for the initial language; this chain
  // ends in updatePlot, which draws the first plot.
  updateLanguageTasks(container, defaultTask, defaultMetric, groupSeeds, taskMetrics);
}
|
|
|
/**
 * Build the language / task / metric selector row for one applet.
 *
 * The language selector is populated from `languageMap` and preset to
 * `defaultLanguage`; task and metric selectors start empty and are filled by
 * updateLanguageTasks / updateMetrics. `defaultTask` and `defaultMetric` are
 * accepted for signature compatibility but are not read here.
 *
 * Note: the change handlers pass `true` as the groupSeeds argument. That
 * argument is only forwarded between updateLanguageTasks and updateMetrics;
 * plotData reads `container.dataset.groupSeeds` instead.
 *
 * @param {HTMLElement} container - Applet root passed to the update handlers.
 * @param {string} defaultLanguage - Initially selected language display name.
 * @param {string} defaultTask - Unused.
 * @param {string} defaultMetric - Unused.
 * @param {string[]} taskMetrics - Statistic keys forwarded to the handlers.
 * @returns {HTMLDivElement} The `.controls` element (not yet attached).
 */
function createControls(container, defaultLanguage, defaultTask, defaultMetric, taskMetrics) {
  const wrapper = document.createElement('div');
  wrapper.className = 'controls';

  // A new language resets task and metric; a new task resets the metric.
  const onLanguageChange = () => updateLanguageTasks(container, '', '', true, taskMetrics);
  const onTaskChange = () => updateMetrics(container, '', true, taskMetrics);
  const onMetricChange = () => updatePlot(container, taskMetrics);

  const languageSelect = createSelect('language', Object.keys(languageMap), onLanguageChange);
  languageSelect.value = defaultLanguage;
  const taskSelect = createSelect('task', [], onTaskChange);
  const metricSelect = createSelect('metric', [], onMetricChange);

  const rows = [
    ['Language:', languageSelect],
    ['Task:', taskSelect],
    ['Metric:', metricSelect],
  ];
  rows.forEach(([labelText, select]) => {
    wrapper.appendChild(createControlGroup(labelText, select));
  });

  return wrapper;
}
|
|
|
/**
 * Create a `<select>` element with the given id, one `<option>` per entry in
 * `options` (each entry used as both value and label), and `onChangeHandler`
 * registered for the `change` event.
 *
 * @param {string} id - Element id (looked up later via `#id` inside the applet).
 * @param {string[]} options - Option values / labels, in display order.
 * @param {(event: Event) => void} onChangeHandler - Change listener.
 * @returns {HTMLSelectElement}
 */
function createSelect(id, options, onChangeHandler) {
  const select = document.createElement('select');
  select.id = id;

  for (const label of options) {
    const entry = document.createElement('option');
    entry.value = label;
    entry.textContent = label;
    select.appendChild(entry);
  }

  select.addEventListener('change', onChangeHandler);
  return select;
}
|
|
|
/**
 * Wrap an input in a `.control-group` div, preceded by a `.control-label`
 * label showing `labelText`.
 *
 * @param {string} labelText - Visible label text.
 * @param {HTMLElement} inputElement - Control placed after the label.
 * @returns {HTMLDivElement}
 */
function createControlGroup(labelText, inputElement) {
  const wrapper = document.createElement('div');
  const caption = document.createElement('label');

  wrapper.className = 'control-group';
  caption.className = 'control-label';
  caption.textContent = labelText;

  wrapper.appendChild(caption);
  wrapper.appendChild(inputElement);
  return wrapper;
}
|
|
|
/**
 * Repopulate the task selector for the currently selected language, then
 * refresh the metric selector (and through it the plot).
 *
 * Shows a "Loading tasks..." placeholder while the list is fetched. If the
 * language has no tasks, or the lookup / downstream update throws, a
 * placeholder option is shown and the plot is cleared.
 *
 * @param {HTMLElement} container - Applet root.
 * @param {string} [defaultTask=''] - Task to select if present in the list;
 *   otherwise the first task is selected.
 * @param {string} [defaultMetric=''] - Forwarded to updateMetrics.
 * @param {boolean} groupSeeds - Forwarded to updateMetrics.
 * @param {string[]} taskMetrics - Forwarded to updateMetrics.
 * @returns {Promise<void>}
 */
async function updateLanguageTasks(container, defaultTask = '', defaultMetric = '', groupSeeds, taskMetrics) {
  const languageName = container.querySelector('#language').value;
  const taskSelect = container.querySelector('#task');
  const langCode = languageMap[languageName];

  taskSelect.innerHTML = '<option value="">Loading tasks...</option>';

  try {
    const tasks = await getTasksForLanguage(langCode);
    taskSelect.innerHTML = '';

    const hasTasks = tasks.length > 0;
    if (!hasTasks) {
      taskSelect.innerHTML = '<option value="">No tasks available</option>';
      clearPlot(container);
      return;
    }

    for (const taskName of tasks) {
      const entry = document.createElement('option');
      entry.value = taskName;
      // Long names are shortened in the list; the tooltip keeps the full name.
      entry.textContent = truncateText(taskName, 25);
      entry.title = taskName;
      taskSelect.appendChild(entry);
    }

    const keepDefault = defaultTask && tasks.includes(defaultTask);
    if (keepDefault) {
      taskSelect.value = defaultTask;
    } else {
      taskSelect.selectedIndex = 0;
    }

    await updateMetrics(container, defaultMetric, groupSeeds, taskMetrics);
  } catch (error) {
    console.error('Error fetching tasks:', error);
    taskSelect.innerHTML = '<option value="">Error loading tasks</option>';
    clearPlot(container);
  }
}
|
|
|
/**
 * Look up the task names for a language code in `taskLists`.
 * Declared async, so it always returns a Promise, although the lookup itself
 * is synchronous. Unknown codes resolve to an empty array.
 *
 * @param {string|undefined} langCode
 * @returns {Promise<string[]>}
 */
async function getTasksForLanguage(langCode) {
  const tasks = taskLists[langCode];
  return tasks ? tasks : [];
}
|
|
|
/**
 * Repopulate the metric selector for the current language/task, then redraw
 * the plot.
 *
 * Shows a "Loading metrics..." placeholder while the metric list is fetched.
 * On any error (fetch or plot update) an error placeholder is shown and the
 * plot is cleared.
 *
 * @param {HTMLElement} container - Applet root.
 * @param {string} [defaultMetric=''] - Metric to select if available;
 *   otherwise the first metric (if any) is selected.
 * @param {boolean} groupSeeds - Accepted for signature compatibility; not read here.
 * @param {string[]} taskMetrics - Forwarded to updatePlot.
 * @returns {Promise<void>}
 */
async function updateMetrics(container, defaultMetric = '', groupSeeds, taskMetrics) {
  const pick = (selector) => container.querySelector(selector);
  const languageName = pick('#language').value;
  const taskName = pick('#task').value;
  const langCode = languageMap[languageName];
  const metricSelect = pick('#metric');

  metricSelect.innerHTML = '<option value="">Loading metrics...</option>';

  try {
    const metrics = await getMetricsForTask(langCode, taskName);
    metricSelect.innerHTML = '';

    for (const metricName of metrics) {
      const entry = document.createElement('option');
      entry.value = metricName;
      entry.textContent = metricName;
      metricSelect.appendChild(entry);
    }

    const keepDefault = defaultMetric && metrics.includes(defaultMetric);
    if (keepDefault) {
      metricSelect.value = defaultMetric;
    } else if (metricSelect.options.length > 0) {
      metricSelect.selectedIndex = 0;
    }

    await updatePlot(container, taskMetrics);
  } catch (error) {
    console.error('Error fetching metrics:', error);
    metricSelect.innerHTML = '<option value="">Error loading metrics</option>';
    clearPlot(container);
  }
}
|
|
|
/**
 * Read the distinct metric names listed in a task's stats CSV
 * (`data/nanotron_tasks/<langCode>/<task>_stats.csv`, `metric` column).
 *
 * Empty metric cells are skipped and first-seen order is kept. Download or
 * parse errors are logged and reject the returned promise.
 *
 * @param {string} langCode - Language code (directory name).
 * @param {string} task - Task name (file stem).
 * @returns {Promise<string[]>}
 */
async function getMetricsForTask(langCode, task) {
  const statsUrl = `data/nanotron_tasks/${langCode}/${task}_stats.csv`;

  return new Promise((resolve, reject) => {
    Papa.parse(statsUrl, {
      download: true,
      header: true,
      complete(results) {
        const names = results.data.map((row) => row.metric).filter(Boolean);
        resolve(Array.from(new Set(names)));
      },
      error(error) {
        console.error('Error fetching metrics:', error);
        reject(error);
      },
    });
  });
}
|
|
|
/**
 * Most recent plot request per applet container. updatePlot stores a fresh
 * token here on every call and only renders a response whose token is still
 * the latest, so a slow, superseded fetch cannot overwrite a newer selection.
 * @type {WeakMap<object, object>}
 */
const latestPlotRequest = new WeakMap();

/**
 * Download and parse one CSV file (header row, dynamic typing).
 *
 * @param {string} url
 * @returns {Promise<object>} Resolves with the PapaParse results object
 *   (`.data` holds one object per row); rejects with PapaParse's error.
 */
function fetchCsv(url) {
  return new Promise((resolve, reject) => {
    Papa.parse(url, {
      download: true,
      header: true,
      dynamicTyping: true,
      complete: resolve,
      error: reject,
    });
  });
}

/**
 * Redraw the plot and statistics for the applet's current
 * language / task / metric selection.
 *
 * Loads `<task>_data.csv` and `<task>_stats.csv` for the selection in parallel
 * and hands them to plotData. If any selector is empty the plot is cleared
 * instead. Errors are logged and, if this is still the latest request, clear
 * the plot.
 *
 * @param {HTMLElement} container - Applet root element.
 * @param {string[]} taskMetrics - Statistic keys to display.
 * @returns {Promise<void>} Settles after this request has been rendered,
 *   dropped as stale, or has failed (failures are logged, not rethrown).
 */
function updatePlot(container, taskMetrics) {
  const language = container.querySelector('#language').value;
  const task = container.querySelector('#task').value;
  const metric = container.querySelector('#metric').value;
  const title = container.dataset.title;
  const langCode = languageMap[language];

  // Register this call as the latest request before any early exit, so that
  // clearing the selection also invalidates fetches still in flight.
  const requestToken = {};
  latestPlotRequest.set(container, requestToken);
  const isLatest = () => latestPlotRequest.get(container) === requestToken;

  if (!langCode || !task || !metric) {
    clearPlot(container);
    return Promise.resolve();
  }

  const dataUrl = `data/nanotron_tasks/${langCode}/${task}_data.csv`;
  const statsUrl = `data/nanotron_tasks/${langCode}/${task}_stats.csv`;

  return Promise.all([fetchCsv(dataUrl), fetchCsv(statsUrl)])
    .then(([dataResult, statsResult]) => {
      if (!isLatest()) {
        return; // A newer selection was requested meanwhile; drop this response.
      }
      plotData(container, dataResult.data, statsResult.data, metric, title, taskMetrics);
    })
    .catch((error) => {
      console.error('Error parsing CSV:', error);
      if (isLatest()) {
        clearPlot(container);
      }
    });
}
|
|
|
/**
 * Compute the y-axis range for a set of traces: the span of all finite y
 * values, with each bound pushed outward by 5% of its own magnitude. For
 * non-negative data this is [0.95 * min, 1.05 * max]; unlike plain scaling it
 * also pads negative bounds away from the data rather than into it.
 *
 * @param {{y: Array<*>}[]} traces
 * @returns {[number, number]|undefined} The range, or undefined when there is
 *   no finite data or the padded span is empty, so the range is left unset.
 */
function computeYAxisRange(traces) {
  let min = Infinity;
  let max = -Infinity;
  // Plain loops instead of Math.min(...values): no argument-count limit, and
  // null/NaN/undefined values are skipped instead of poisoning the result.
  for (const trace of traces) {
    for (const value of trace.y) {
      if (Number.isFinite(value)) {
        min = Math.min(min, value);
        max = Math.max(max, value);
      }
    }
  }
  if (min > max) {
    return undefined;
  }
  const lower = min - Math.abs(min) * 0.05;
  const upper = max + Math.abs(max) * 0.05;
  return lower < upper ? [lower, upper] : undefined;
}

/**
 * Run the data pipeline for one metric and draw it, then show its statistics.
 *
 * Pipeline: sort rows by tokens -> group by run (optionally averaging seeds,
 * per `data-group-seeds`) -> resample every run onto a shared token grid ->
 * smooth -> build Plotly traces.
 *
 * @param {HTMLElement} container - Applet root element.
 * @param {object[]} data - Rows of the task's data CSV.
 * @param {object[]} stats - Rows of the task's stats CSV.
 * @param {string} metric - Column to plot.
 * @param {string|undefined} title - Plot title (from `data-title`).
 * @param {string[]} taskMetrics - Statistic keys to display.
 */
function plotData(container, data, stats, metric, title, taskMetrics) {
  const groupSeeds = container.dataset.groupSeeds === 'true';
  const sortedData = sortDataByTokens(data);
  const groupedData = groupDataByRunname(sortedData, groupSeeds, metric);
  const interpolatedData = interpolateData(groupedData, metric);
  const smoothedData = smoothData(interpolatedData, metric);
  const traces = createTraces(smoothedData, metric);

  const plotContainer = container.querySelector('.plot-container');
  const yRange = computeYAxisRange(traces);

  const layout = _.merge({}, DEFAULT_LAYOUT, {
    // `data-title` may be absent; avoid rendering the literal text "undefined".
    title: { text: title || '' },
    xaxis: {
      title: { text: 'Training Tokens (billions)' },
      tickvals: [0, 5, 10, 15, 20, 25],
      ticktext: ['0', '5B', '10B', '15B', '20B', '25B'],
      tickangle: 45,
      range: [0, 30],
    },
    yaxis: {
      title: { text: 'Score' },
      ...(yRange ? { range: yRange } : {}),
    },
    width: container.offsetWidth,
  });

  Plotly.newPlot(plotContainer, traces, layout, { responsive: true });

  displayStatistics(container, stats, metric, taskMetrics);
}
|
|
|
/**
 * Format one statistic for display with two decimals.
 * Non-numeric cells (null, strings such as "NaN") and NaN show as "N/A"
 * instead of throwing on `.toFixed`.
 *
 * @param {*} value
 * @returns {string}
 */
function formatStatValue(value) {
  return typeof value === 'number' && !Number.isNaN(value) ? value.toFixed(2) : 'N/A';
}

/**
 * Render the selected statistics for `metric` into the applet's
 * `.stats-container`.
 *
 * Looks up the stats row whose `metric` column equals `metric` and shows the
 * statistics named in `taskMetrics` (monotonicity, snr, ordering, randomness)
 * in that fixed order. A single requested statistic uses the
 * `compact-stats-single` class. Without a matching row a
 * "No statistics available" message is shown.
 *
 * @param {HTMLElement} container - Applet root element.
 * @param {object[]} stats - Rows of the task's stats CSV.
 * @param {string} metric - Metric being plotted.
 * @param {string[]} taskMetrics - Statistic keys to display.
 */
function displayStatistics(container, stats, metric, taskMetrics) {
  const statsContainer = container.querySelector('.stats-container');
  const metricStats = stats.find((stat) => stat.metric === metric);

  if (!metricStats) {
    statsContainer.innerHTML = '<p>No statistics available for this metric.</p>';
    return;
  }

  // [taskMetrics key, stats CSV column, tooltip, label]
  const statDescriptors = [
    ['monotonicity', 'avg_spearman', 'Average Spearman Correlation', 'Monotonicity'],
    ['snr', 'avg_snr', 'Average Signal-to-Noise Ratio', 'Signal-to-Noise'],
    ['ordering', 'avg_kendall_tau_a', 'Average Kendall Tau-a', 'Ordering Consistency'],
    ['randomness', 'max_n_std', 'Max N Standard Deviations', 'Non-Randomness'],
  ];

  const spans = statDescriptors
    .filter(([key]) => taskMetrics.includes(key))
    .map(([, column, tooltip, label]) =>
      `<span title="${tooltip}">${label}: ${formatStatValue(metricStats[column])}</span>`);

  const className = taskMetrics.length === 1 ? 'compact-stats-single' : 'compact-stats';
  // Newline-separated so adjacent inline spans keep a whitespace gap when rendered.
  statsContainer.innerHTML = `\n<div class="${className}">\n${spans.join('\n')}\n</div>\n`;
}
|
|
|
/**
 * Choose at most about ten tick positions from a list of token counts.
 *
 * Duplicates are removed and values sorted ascending. If ten or fewer remain
 * they are all returned; otherwise every k-th value is kept, starting with the
 * first, where k = ceil(count / 10).
 *
 * @param {number[]} tokens
 * @returns {number[]}
 */
function getReducedTickValues(tokens) {
  const maxTicks = 10;
  const ascending = Array.from(new Set(tokens)).sort((a, b) => a - b);

  if (ascending.length <= maxTicks) {
    return ascending;
  }

  const step = Math.ceil(ascending.length / maxTicks);
  const picked = [];
  for (let position = 0; position < ascending.length; position += step) {
    picked.push(ascending[position]);
  }
  return picked;
}
|
|
|
/**
 * Abbreviate a count as billions / millions / thousands with one decimal
 * ("2.5B", "1.0M", "1.5K"); values below 1000 are returned via toString().
 *
 * @param {number} value
 * @returns {string}
 */
function formatTickLabel(value) {
  const scales = [
    [1e9, 'B'],
    [1e6, 'M'],
    [1e3, 'K'],
  ];
  const scale = scales.find(([threshold]) => value >= threshold);
  if (!scale) {
    return value.toString();
  }
  const [divisor, suffix] = scale;
  return `${(value / divisor).toFixed(1)}${suffix}`;
}
|
|
|
/**
 * Compute summary statistics over runs of one metric.
 * This function is not called from the code visible in this chunk; the
 * displayed statistics are read from the stats CSV instead.
 *
 * The first run whose name contains "baseline" (case-insensitive) is the
 * baseline; every other run contributes to:
 *  - avg_spearman: mean over runs of spearmanCorrelation(tokens, scores).
 *  - avg_kendall_tau_a: mean over runs of the mean kendallTauA value computed
 *    over the second half of each run (see NOTE below).
 *  - avg_snr: sum over runs of mean(scores) / std(scores), divided by the
 *    number of non-baseline runs; only accumulated when a baseline exists.
 *  - max_n_std: max over runs of (max(scores) - mean(baseline scores)) /
 *    std(scores), floored at 0; only when a baseline exists.
 *
 * NOTE(review): with no non-baseline runs the averages are 0 / 0 = NaN, and a
 * run with zero standard deviation makes the SNR / n-std ratios non-finite.
 *
 * @param {Object<string, object[]>} data - Rows per run name; each row has
 *   `tokens` and a `metric` column.
 * @param {string} metric
 * @returns {{avg_spearman: number, avg_kendall_tau_a: number, avg_snr: number, max_n_std: number}}
 */
function computeStatistics(data, metric) {
  const runNames = Object.keys(data);
  const baselineRun = runNames.find((name) => name.toLowerCase().includes('baseline'));
  const comparedRuns = runNames.filter((name) => name !== baselineRun);

  let spearmanTotal = 0;
  let kendallTotal = 0;
  let snrTotal = 0;
  let maxNStd = 0;

  for (const run of comparedRuns) {
    const rows = data[run];
    const tokens = rows.map((row) => row.tokens);
    const scores = rows.map((row) => row[metric]);

    spearmanTotal += spearmanCorrelation(tokens, scores);

    const taus = [];
    for (let cut = Math.floor(rows.length / 2); cut < rows.length - 1; cut += 1) {
      // NOTE(review): both arguments share the same prefix, so every untied
      // pair is concordant and this measures the share of untied pairs rather
      // than ranking stability between steps. Confirm the intended inputs.
      taus.push(kendallTauA(scores.slice(0, cut + 1), scores.slice(0, cut + 2)));
    }
    kendallTotal += _.mean(taus);

    if (baselineRun) {
      const baselineScores = data[baselineRun].map((row) => row[metric]);
      const spread = standardDeviation(scores);
      snrTotal += _.mean(scores) / spread;
      maxNStd = Math.max(maxNStd, (_.max(scores) - _.mean(baselineScores)) / spread);
    }
  }

  const runCount = comparedRuns.length;
  return {
    avg_spearman: spearmanTotal / runCount,
    avg_kendall_tau_a: kendallTotal / runCount,
    avg_snr: snrTotal / runCount,
    max_n_std: maxNStd,
  };
}
|
|
|
/**
 * Spearman rank correlation between two equal-length numeric sequences.
 *
 * Values are given fractional ranks (tied values share the average of the
 * positions they occupy) and the coefficient is the Pearson correlation of
 * those ranks. This is the standard definition and remains correct with ties.
 * The shortcut 1 - 6·Σd² / (n(n² - 1)) is only exact for untied data and, fed
 * min-ranks, can even fall outside [-1, 1].
 *
 * @param {number[]} x
 * @param {number[]} y - Same length as `x`; only the first x.length entries are used.
 * @returns {number} Coefficient in [-1, 1], or NaN when fewer than two points
 *   are given or either sequence is constant (correlation undefined).
 */
function spearmanCorrelation(x, y) {
  // Fractional ("average") ranks, 1-based, in the original element order.
  const averageRanks = (values) => {
    const order = values.map((value, index) => index).sort((a, b) => values[a] - values[b]);
    const ranks = new Array(values.length);
    let start = 0;
    while (start < order.length) {
      let end = start;
      while (end + 1 < order.length && values[order[end + 1]] === values[order[start]]) {
        end += 1;
      }
      const sharedRank = (start + end) / 2 + 1;
      for (let k = start; k <= end; k += 1) {
        ranks[order[k]] = sharedRank;
      }
      start = end + 1;
    }
    return ranks;
  };

  const n = x.length;
  if (n < 2) {
    return NaN;
  }

  const rankX = averageRanks(x);
  const rankY = averageRanks(y);

  let meanX = 0;
  let meanY = 0;
  for (let i = 0; i < n; i += 1) {
    meanX += rankX[i];
    meanY += rankY[i];
  }
  meanX /= n;
  meanY /= n;

  let covariance = 0;
  let varianceX = 0;
  let varianceY = 0;
  for (let i = 0; i < n; i += 1) {
    const dx = rankX[i] - meanX;
    const dy = rankY[i] - meanY;
    covariance += dx * dy;
    varianceX += dx * dx;
    varianceY += dy * dy;
  }

  // A constant sequence gives 0 / 0 = NaN: the correlation is undefined.
  return covariance / Math.sqrt(varianceX * varianceY);
}
|
|
|
/**
 * Rank each value by its position in an ascending copy of `data`, 1-based.
 * Tied values all receive the smallest rank of their group ("min" ranking),
 * e.g. [30, 10, 20, 10] -> [4, 1, 3, 1]. Values not found by indexOf (NaN)
 * get rank 0.
 *
 * @param {number[]} data
 * @returns {number[]}
 */
function rankData(data) {
  const ascending = [...data];
  ascending.sort((lhs, rhs) => lhs - rhs);
  return data.map((value) => 1 + ascending.indexOf(value));
}
|
|
|
/**
 * Kendall's tau-a: (concordant pairs - discordant pairs) / (n(n - 1) / 2),
 * with n = x.length. A pair tied in either sequence counts as neither.
 * Returns NaN for fewer than two elements (0 / 0).
 *
 * @param {number[]} x
 * @param {number[]} y
 * @returns {number}
 */
function kendallTauA(x, y) {
  const n = x.length;
  let balance = 0; // concordant pairs minus discordant pairs

  for (let j = 1; j < n; j += 1) {
    for (let i = 0; i < j; i += 1) {
      const agreement = Math.sign(x[j] - x[i]) * Math.sign(y[j] - y[i]);
      if (agreement > 0) {
        balance += 1;
      } else if (agreement < 0) {
        balance -= 1;
      }
    }
  }

  return balance / (n * (n - 1) / 2);
}
|
|
|
/**
 * Population standard deviation (divides by n, not n - 1), using lodash
 * `_.mean` for both the mean and the mean squared deviation.
 *
 * @param {number[]} values
 * @returns {number}
 */
function standardDeviation(values) {
  const center = _.mean(values);
  const variance = _.mean(values.map((value) => {
    const offset = value - center;
    return offset * offset;
  }));
  return Math.sqrt(variance);
}
|
|
|
/**
 * Resample every run onto the union of token counts seen in any run.
 *
 * For each token count a run lacks, `metric` is linearly interpolated between
 * the run's nearest lower and upper rows, and other fields are copied from the
 * lower row. Outside a run's own token span the nearest endpoint row is
 * repeated with `tokens` replaced. Rows the run already has are returned
 * as-is (same object).
 *
 * @param {Object<string, object[]>} data - Rows per run name.
 * @param {string} metric - Column to interpolate.
 * @returns {Object<string, object[]>} Rows per run, one per shared token
 *   count, in ascending token order.
 */
function interpolateData(data, metric) {
  // The shared token grid does not depend on the run being processed, so it is
  // built once here instead of once per run inside the callback below.
  const allTokens = _.uniq(_.flatMap(Object.values(data), (runRows) => runRows.map((r) => r.tokens)))
    .sort((a, b) => a - b);

  return _.mapValues(data, (rows) => {
    const sortedRows = _.sortBy(rows, 'tokens');

    return allTokens.map((token) => {
      const exactMatch = _.find(sortedRows, { tokens: token });
      if (exactMatch) return exactMatch;

      const lowerRow = _.findLast(sortedRows, (r) => r.tokens < token);
      const upperRow = _.find(sortedRows, (r) => r.tokens > token);

      // Before the run's first / after its last point: hold the endpoint value.
      if (!lowerRow) return { ...upperRow, tokens: token };
      if (!upperRow) return { ...lowerRow, tokens: token };

      const ratio = (token - lowerRow.tokens) / (upperRow.tokens - lowerRow.tokens);
      const interpolatedMetric = lowerRow[metric] + (upperRow[metric] - lowerRow[metric]) * ratio;

      return {
        ...lowerRow,
        tokens: token,
        [metric]: interpolatedMetric,
      };
    });
  });
}
|
|
|
/**
 * Apply a trailing moving average to `metric` in every run.
 *
 * Each point becomes the mean of itself and up to `windowSize - 1` preceding
 * points of the same run (fewer at the start of the run). Rows are copied;
 * the input rows are not modified.
 *
 * @param {Object<string, object[]>} data - Rows per run name, in plot order.
 * @param {string} metric - Column to smooth.
 * @param {number} [windowSize=3] - Number of points averaged per output point.
 * @returns {Object<string, object[]>}
 */
function smoothData(data, metric, windowSize = 3) {
  const smoothRun = (rows) => rows.map((row, position, allRows) => {
    const firstIncluded = Math.max(0, position - windowSize + 1);
    const trailing = allRows.slice(firstIncluded, position + 1);
    return { ...row, [metric]: _.meanBy(trailing, (r) => r[metric]) };
  });
  return _.mapValues(data, smoothRun);
}
|
|
|
/**
 * Return a new array of rows stably sorted by ascending `tokens`
 * (lodash `_.sortBy`; the input array is left untouched).
 *
 * @param {object[]} data
 * @returns {object[]}
 */
function sortDataByTokens(data) {
  const sortKeys = ['tokens'];
  return _.sortBy(data, sortKeys);
}
|
|
|
/**
 * Group data rows into runs keyed by display name.
 *
 * Rows whose `runname` is null/undefined or the literal 'null_undefined' are
 * dropped. Without seed grouping, each (run, seed) pair is its own group,
 * keyed `<display name>_<seed>`. With seed grouping, rows of a run that share
 * a `tokens` value are merged into one row whose `metric` is the mean across
 * seeds. Seeds with a missing or non-numeric value at that step are left out
 * of the mean rather than counted as 0; if no seed has a value the merged
 * value is 0, as before.
 *
 * @param {object[]} data - Data CSV rows (`runname`, `seed`, `tokens`, metric columns).
 * @param {boolean} groupSeeds - Average across seeds when true.
 * @param {string} metric - Column averaged when grouping seeds.
 * @returns {Object<string, object[]>} Rows per group key. With seed grouping,
 *   each run's rows are in ascending token order.
 */
function groupDataByRunname(data, groupSeeds, metric) {
  const validRows = data.filter((row) => row.runname != null && row.runname !== 'null_undefined');

  if (!groupSeeds) {
    return _.groupBy(validRows, (row) => `${processRunName(row.runname)}_${row.seed}`);
  }

  const byRun = _.groupBy(validRows, (row) => processRunName(row.runname));

  return _.mapValues(byRun, (rows) => {
    const byStep = _.groupBy(rows, 'tokens');

    return Object.values(byStep)
      .map((stepRows) => {
        const seedValues = stepRows
          .map((row) => parseFloat(row[metric]))
          .filter((value) => !Number.isNaN(value));
        const meanMetric = seedValues.length > 0
          ? seedValues.reduce((sum, value) => sum + value, 0) / seedValues.length
          : 0;
        return {
          ...stepRows[0],
          [metric]: meanMetric,
        };
      })
      // Group keys are object keys, and integer-like keys enumerate before
      // others (e.g. "1", "2" before "0.5"), so restore token order here.
      .sort((a, b) => a.tokens - b.tokens);
  });
}
|
|
|
/**
 * Map a raw run name to its display label using `runNameMap`.
 *
 * The first key (in runNameMap order) contained in the run name wins; if none
 * matches, the run name is returned unchanged.
 *
 * @param {string|number|boolean} runname - Raw run name. PapaParse's
 *   dynamicTyping can turn numeric-looking names into numbers.
 * @returns {string|number|boolean} Display label, or `runname` itself.
 */
function processRunName(runname) {
  // Non-string names (numbers from dynamicTyping) have no .includes(); match
  // against their string form instead of throwing.
  const name = String(runname);
  for (const [key, label] of Object.entries(runNameMap)) {
    if (name.includes(key)) {
      return label;
    }
  }
  return runname;
}
|
|
|
/**
 * Build one Plotly line+marker trace per run.
 *
 * Runs are ordered alphabetically (localeCompare), except that runs whose name
 * contains "baseline" in any letter case are placed last. Without the case
 * fold, processRunName's "Baseline" label never matched. Colours come from
 * getColorForTrace by position in that order.
 *
 * @param {Object<string, object[]>} groupedData - Rows per run name.
 * @param {string} metric - Column plotted on the y axis.
 * @returns {object[]} Plotly trace objects.
 */
function createTraces(groupedData, metric) {
  const colorsMapping = new Map();
  const isBaseline = (name) => name.toLowerCase().includes('baseline');

  // Consistent comparator: non-baselines first, then alphabetical within each group.
  const sortedRunnames = Object.keys(groupedData).sort((a, b) => {
    const baselineOrder = Number(isBaseline(a)) - Number(isBaseline(b));
    return baselineOrder !== 0 ? baselineOrder : a.localeCompare(b);
  });

  return sortedRunnames.map((runname, index) => {
    const color = getColorForTrace(runname, colorsMapping, index);
    const rows = groupedData[runname];
    return {
      x: rows.map((row) => row.tokens),
      y: rows.map((row) => row[metric]),
      name: runname,
      // `type` and `mode` are trace-level attributes; only `width` belongs in `line`.
      type: LINE_SETTINGS.type,
      mode: LINE_SETTINGS.mode,
      line: {
        color: color,
        shape: 'spline',
        width: LINE_SETTINGS.width,
      },
      marker: {
        color: color,
        size: 6,
      },
    };
  });
}
|
|
|
/**
 * Return the colour already assigned to `traceName` in `colorsMapping`, or
 * assign and remember `getColor(index)` if it has none (or a falsy one).
 *
 * @param {string} traceName
 * @param {Map<string, string>} colorsMapping - Cache, mutated on first use of a name.
 * @param {number} index - Palette index used for a new assignment.
 * @returns {string}
 */
function getColorForTrace(traceName, colorsMapping, index) {
  const cached = colorsMapping.get(traceName);
  if (!cached) {
    const assigned = getColor(index);
    colorsMapping.set(traceName, assigned);
    return assigned;
  }
  return cached;
}
|
|
|
/**
 * Remove the applet's plot and empty its statistics panel.
 *
 * The statistics are cleared too, so figures for a previous selection are
 * not left on screen after the plot is removed, e.g. when a language has no
 * tasks or loading failed. Missing containers are ignored.
 *
 * @param {HTMLElement} container - Applet root element.
 */
function clearPlot(container) {
  const plotContainer = container.querySelector('.plot-container');
  if (plotContainer) {
    Plotly.purge(plotContainer);
  }

  const statsContainer = container.querySelector('.stats-container');
  if (statsContainer) {
    statsContainer.innerHTML = '';
  }
}
|
|
|
/**
 * Shorten `text` to at most `maxLength` characters for display. Over-long
 * text keeps its first `maxLength - 2` characters followed by "..".
 *
 * @param {string} text
 * @param {number} maxLength - Maximum length of the returned string (for maxLength >= 2).
 * @returns {string}
 */
function truncateText(text, maxLength) {
  if (text.length <= maxLength) return text;
  // String#slice replaces the legacy String#substr. Clamping at 0 keeps the
  // old substr result ('' + '..') when maxLength < 2; slice would otherwise
  // treat a negative end as counting from the end of the string.
  return text.slice(0, Math.max(0, maxLength - 2)) + '..';
}
|
|
|
|