// Shared UI state for this chart. Kept as top-level `var`s (not const/let)
// so the declarations stay redeclarable if other scripts on the page use
// the same names.
var state = {
  dataset_size: 15000,
  threshold: .8,
  label: 8
}

// Root element of the chart; cleared on load and tagged for screen readers.
var sel = d3.select('.accuracy-v-privacy-class').html('')
  .at({role: 'graphics-document', 'aria-label': `Line chart showing that high accuracy models can still perform poorly on some digit classes.`})

/**
 * Loads the per-test-point confidences for the currently selected training
 * set size and aggregates them into one point per (digit label, model column).
 *
 * @returns {Promise<{data: Object[], byLabel: Object[][]}>} `data` is the flat
 *   list of points ({key, colIndex, epsilon, digitClass, label, accuracy});
 *   `byLabel` is the same points grouped by label, one group per line.
 */
async function loadData(){
  // Read the selection once: it can change while the requests are in flight,
  // and the CSV and the metadata filter must agree on the dataset size.
  var datasetSize = state.dataset_size

  // The two files are independent, so fetch them in parallel.
  var [rawData, metadata] = await Promise.all([
    util.getFile(`cns-cache/grid_${datasetSize}trainpoints_test_labels.csv`),
    util.getFile('cns-cache/model_grid_test_accuracy.json'),
  ])

  rawData.forEach(d => {
    // Drop the column whose header is empty (presumably a row index written
    // by the exporter - TODO confirm).
    delete d['']
    d.i = +d.i
    d.label = +d.label
  })

  // Index the model metadata by the matching CSV column name.
  var aVal2Meta = {}
  metadata
    .filter(d => d.dataset_size == datasetSize)
    .forEach(d => aVal2Meta['aVal_' + d.aVal] = d)

  var allCols = d3.keys(rawData[0])
    .filter(d => d.includes('aVal'))
    .map(key => {
      var {epsilon} = aVal2Meta[key]
      return {key, epsilon}
    })

  var byDigit = d3.nestBy(rawData, d => d.label)
  byDigit.forEach(d => {
    d.label = +d.key
  })
  byDigit.forEach(digitClass => {
    digitClass.cols = allCols.map(({key, epsilon}, colIndex) => {
      return {
        key,
        colIndex,
        epsilon,
        digitClass,
        label: digitClass.label,
        // Share of this digit's test points above the confidence threshold.
        // Both sides are coerced so the comparison is always numeric; the
        // column values are not converted to numbers above.
        accuracy: d3.mean(digitClass, d => +d[key] > +state.threshold)
      }
    })
  })

  // Keep only points inside the plotted epsilon range. The bounds are read
  // as [1] <= epsilon <= [0], i.e. the extent appears to be stored high-to-low.
  var data = _.flatten(byDigit.map(d => d.cols))
    .filter(d => util.epsilonExtent[1] <= d.epsilon && d.epsilon <= util.epsilonExtent[0])
  var byLabel = d3.nestBy(data, d => d.label)
  byLabel.forEach(d => {
    d.label = d.key
  })

  return {data, byLabel}
}

/**
 * Builds the line chart (axes, one line per digit label, one labelled circle
 * per point) and returns handles for updating it.
 *
 * @returns {Promise<{c: Object, updateDatasetSize: Function, updateThreshold: Function}>}
 *   `c` is the d3.conventions object; `updateDatasetSize` reloads data for
 *   `state.dataset_size`; `updateThreshold` recomputes accuracies for
 *   `state.threshold`.
 */
async function initChart(){
  var {data, byLabel} = await loadData()

  var c = d3.conventions({
    sel: sel.append('div'),
    height: 400,
    margin: {bottom: 75, top: 5},
    layers: 'ds',
  })

  c.x = d3.scaleLog().domain(util.epsilonExtent).range(c.x.range())
  c.xAxis = d3.axisBottom(c.x).tickFormat(d => {
    // Only label ticks whose first non-zero digit is 1; other ticks return
    // nothing and so get no label.
    var rv = d + ''
    if (rv.split('').filter(d => d !=0 && d != '.')[0] == 1) return rv
  })

  c.yAxis.tickFormat(d => d3.format('.0%')(d))
  d3.drawAxis(c)
  util.addAxisLabel(c, 'Higher Privacy →', '')
  util.ggPlotBg(c, false)

  // Secondary x-axis caption, centred under the axis in the div layer.
  c.layers[0].append('div')
    .st({fontSize: 12, color: '#555', width: 120*2, textAlign: 'center', lineHeight: '1.3em', verticalAlign: 'top'})
    .translate([c.width/2 - 120, c.height + 45])
    .html('in ε')

  var line = d3.line().x(d => c.x(d.epsilon)).y(d => c.y(d.accuracy))

  var lineSel = c.svg.append('g').appendMany('path.accuracy-line', byLabel)
    .at({
      d: line,
      fill: 'none',
      stroke: '#000',
    })
    .on('mousemove', setActiveLabel)

  var circleSel = c.svg.append('g')
    .appendMany('g.accuracy-circle', data)
    .translate(d => [c.x(d.epsilon), c.y(d.accuracy)])
    .on('mousemove', setActiveLabel)

  circleSel.append('circle')
    .at({r: 7, stroke: '#fff'})

  circleSel.append('text')
    .text(d => d.label)
    .at({textAnchor: 'middle', fontSize: 10, fill: '#fff', dy: '.33em'})

  setActiveLabel(state)

  /**
   * Marks the line and circles of one digit label as active, raises them
   * above the others, and records the label in `state`.
   *
   * @param {{label: (number|string)}} d - any object carrying a `label`
   *   (a bound datum or `state` itself).
   */
  function setActiveLabel({label}){
    lineSel
      .classed('active', 0)
      .filter(d => d.label == label)
      .classed('active', 1)
      .raise()

    circleSel
      .classed('active', 0)
      .filter(d => d.label == label)
      .classed('active', 1)
      .raise()

    state.label = label
  }

  /**
   * Reloads the data for `state.dataset_size` and transitions the lines and
   * circles to the new values.
   */
  async function updateDatasetSize(){
    var requestedSize = state.dataset_size
    var newData = await loadData()

    // A newer selection was made while this request was loading; let that
    // request update the chart instead of overwriting it with stale data.
    if (requestedSize != state.dataset_size) return

    data = newData.data
    byLabel = newData.byLabel

    lineSel.data(byLabel)
      .transition()
      .at({d: line})

    circleSel.data(data)
      .transition()
      .translate(d => [c.x(d.epsilon), c.y(d.accuracy)])

    // The annotation text is fixed, so remove it once the data changes.
    c.svg.select('text.annotation').remove()
  }

  /**
   * Recomputes every point's accuracy for `state.threshold`, redraws the
   * lines and circles, and updates the y-axis label.
   */
  function updateThreshold(){
    data.forEach(d => {
      // Numeric comparison on both sides (see loadData).
      d.accuracy = d3.mean(d.digitClass, e => +e[d.key] > +state.threshold)
    })

    lineSel.at({d: line})
    circleSel.translate(d => [c.x(d.epsilon), c.y(d.accuracy)])

    c.svg.select('.y .axis-label').text(`Test Points With More Than ${d3.format('.2%')(state.threshold)} Confidence In Label`)

    c.svg.select('text.annotation').remove()
  }
  updateThreshold()

  return {c, updateDatasetSize, updateThreshold}
}

/**
 * Entry point: renders the title, the chart, the training-set-size buttons,
 * the confidence-threshold slider and the initial annotation.
 */
async function init(){
  sel.append('div.chart-title').text('High accuracy models can still perform poorly on some digit classes')

  var chart = await initChart()

  var buttonRowSel = sel.append('div.button-row')
    .st({height: 50})

  var buttonSel = buttonRowSel.append('div')
    .st({width: 500})
    .append('span.chart-title').text('Training points')
    .parent()
    .append('div').st({display: 'inline-block', width: 300, marginLeft: 10})
    .append('div.digit-button-container.dataset_size')
    .appendMany('div.button', [2000, 3750, 7500, 15000, 30000, 60000])
    .text(d3.format(','))
    .classed('active', d => d == state.dataset_size)
    .on('click', d => {
      buttonSel.classed('active', e => e == d)
      state.dataset_size = d
      chart.updateDatasetSize()
    })

  // Native range input; addSliders() below removes it and draws an SVG
  // slider in its place.
  buttonRowSel.append('div.conf-slider')
    .append('span.chart-title').text('Confidence threshold')
    .parent()
    .append('input.slider-native')
    .at({
      type: 'range',
      min: .0001,
      max: .9999,
      step: .0001,
      value: state.threshold,
    })
    .on('input', function(){
      // input.value is a string; store a number so threshold comparisons
      // are never done lexicographically.
      state.threshold = +this.value
      chart.updateThreshold()
    })

  /**
   * Replaces the native range input with an SVG slider (click or drag) that
   * writes to `state.threshold` and refreshes the chart.
   */
  function addSliders(){
    var width = 140
    var height = 30
    var color = '#000'

    var sliders = [
      {key: 'threshold', label: 'Confidence threshold', r: [.0001, .9999]},
    ]
    sliders.forEach(d => {
      d.value = state[d.key]
      d.xScale = d3.scaleLinear().range([0, width]).domain(d.r).clamp(1)
    })

    // Look elements up under this chart's root rather than document-wide, so
    // another chart using the same class names is never touched.
    sel.select('.conf-slider .slider-container').remove()
    sel.select('.slider-native').remove()

    var svgSel = sel.select('.conf-slider').parent()
      .appendMany('div.slider-container', sliders)
      .append('svg').at({width, height})
      .append('g').translate([10, 25])

    var sliderSel = svgSel
      .on('click', function(d){
        d.value = d.xScale.invert(d3.mouse(this)[0])
        renderSliders(d)
      })
      .classed('slider', true)
      .st({cursor: 'pointer'})

    var textSel = sliderSel.append('text.annotation')
      .at({y: -15, fontWeight: 300, textAnchor: 'middle', x: 180/2})

    // Transparent rect that enlarges the clickable area of the slider.
    sliderSel.append('rect')
      .at({width, height, y: -height/2, fill: 'rgba(0,0,0,0)'})

    // Thin full-width track.
    sliderSel.append('path').at({
      d: `M 0 -.5 H ${width}`,
      stroke: color,
      strokeWidth: 1
    })

    // Thick segment from the left edge to the handle.
    var leftPathSel = sliderSel.append('path').at({
      d: `M 0 -.5 H ${width}`,
      stroke: color,
      strokeWidth: 3
    })

    var drag = d3.drag()
      .on('drag', function(d){
        var x = d3.mouse(this)[0]
        d.value = d.xScale.invert(x)
        renderSliders(d)
      })

    var circleSel = sliderSel.append('circle').call(drag)
      .at({r: 7, stroke: '#000'})

    /**
     * Writes the slider value to `state` (when a slider datum is given),
     * repositions handle, track and value label, and refreshes the chart.
     *
     * @param {Object} [d] - slider datum that changed; omitted on first draw.
     */
    function renderSliders(d){
      if (d) state[d.key] = d.value

      circleSel.at({cx: d => d.xScale(d.value)})
      leftPathSel.at({d: d => `M 0 -.5 H ${d.xScale(d.value)}`})
      textSel
        .at({x: d => d.xScale(d.value)})
        .text(d => d3.format('.2%')(d.value))
      chart.updateThreshold()
    }
    renderSliders()
  }
  addSliders()

  // Initial call-out; the update functions remove it on their next run.
  chart.c.svg.append('text.annotation')
    .translate([505, 212])
    .tspans(d3.wordwrap(`8s are correctly predicted with high confidence much more rarely than other digits`, 25), 12)
    .at({textAnchor: 'end'})
}
init()