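// NOTE: "Spaces: Running Running" appears to be hosting-page status text that
// was captured together with this file; it is not part of the script.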
// Shared UI state for this chart. Read by loadData / initChart and written by
// the controls built in init().
var state = {
  // Which `grid_<n>trainpoints_test_labels.csv` file is loaded.
  dataset_size: 15000,
  // A test point counts as "confident" when its value exceeds this threshold.
  threshold: 0.8,
  // Digit class whose line and circles are currently highlighted.
  label: 8
}
// Root container for this chart: clear any previous contents and expose the
// graphic to assistive technology with a short description.
var sel = d3.select('.accuracy-v-privacy-class')
  .html('')
  .at({
    role: 'graphics-document',
    'aria-label': `Line chart showing that high accuracy models can still perform poorly on some digit classes.`,
  })
/**
 * Loads the per-test-point rows for the currently selected dataset size and
 * aggregates them into one accuracy value per (digit label, model column).
 *
 * "Accuracy" here is the fraction of a digit's rows whose value in a given
 * `aVal*` column is greater than `state.threshold`.
 *
 * @returns {Promise<{data: Object[], byLabel: Array[]}>}
 *   `data`: flat list of `{key, colIndex, epsilon, digitClass, label, accuracy}`
 *   entries restricted to the epsilon range in `util.epsilonExtent`;
 *   `byLabel`: the same entries grouped by label (each group carries `.label`).
 */
async function loadData(){
  // Snapshot the values read from `state` so that the CSV and the metadata are
  // matched against the same dataset size even if the user changes it while
  // the requests are still in flight.
  const datasetSize = state.dataset_size
  // Coerce once: if the threshold ever arrives as a string (e.g. from an
  // <input>), comparing it with string cell values would be lexicographic.
  const threshold = +state.threshold

  // The two files are independent, so fetch them in parallel.
  const [rawData, metadata] = await Promise.all([
    util.getFile(`cns-cache/grid_${datasetSize}trainpoints_test_labels.csv`),
    util.getFile('cns-cache/model_grid_test_accuracy.json'),
  ])

  rawData.forEach(d => {
    // Drop the column with an empty header (presumably a row index written by
    // the exporter — TODO confirm).
    delete d['']
    d.i = +d.i
    d.label = +d.label
  })

  // Column name ('aVal_<aVal>') -> metadata entry for the selected dataset size.
  const aVal2Meta = {}
  metadata
    .filter(d => d.dataset_size == datasetSize)
    .forEach(d => aVal2Meta['aVal_' + d.aVal] = d)

  const allCols = d3.keys(rawData[0])
    .filter(key => key.includes('aVal'))
    .map(key => {
      // A column without metadata gets an undefined epsilon and is removed by
      // the epsilon-range filter below instead of throwing here.
      const {epsilon, aVal} = aVal2Meta[key] || {}
      return {key, epsilon, aVal}
    })

  const byDigit = d3.nestBy(rawData, d => d.label)
  byDigit.forEach(digitClass => {
    digitClass.label = +digitClass.key
    digitClass.cols = allCols.map(({key, epsilon}, colIndex) => ({
      key,
      colIndex,
      epsilon,
      // Keep the raw rows reachable so accuracy can be recomputed later for a
      // different threshold without reloading.
      digitClass,
      label: digitClass.label,
      accuracy: d3.mean(digitClass, d => d[key] > threshold),
    }))
  })

  // NOTE(review): the comparison treats util.epsilonExtent as [max, min] —
  // confirm against its definition.
  const data = _.flatten(byDigit.map(d => d.cols))
    .filter(d => util.epsilonExtent[1] <= d.epsilon && d.epsilon <= util.epsilonExtent[0])

  const byLabel = d3.nestBy(data, d => d.label)
  byLabel.forEach(d => {
    // Numeric, matching the labels on `data` and `byDigit`.
    d.label = +d.key
  })

  return {data, byLabel}
}
/**
 * Builds the chart inside `sel`: a log-scaled epsilon x axis, a percentage
 * y axis, one line per digit label and one labelled circle per data point.
 *
 * @returns {Promise<{c: Object, updateDatasetSize: Function, updateThreshold: Function}>}
 *   `c` is the d3.conventions object for the chart; `updateDatasetSize`
 *   reloads data for `state.dataset_size`; `updateThreshold` recomputes
 *   accuracies for `state.threshold` and redraws.
 */
async function initChart(){
  // `let`: both are replaced when another dataset size is loaded.
  let {data, byLabel} = await loadData()

  // Incremented per updateDatasetSize call so stale responses can be ignored.
  let latestRequestId = 0

  const c = d3.conventions({
    sel: sel.append('div'),
    height: 400,
    margin: {bottom: 75, top: 5},
    layers: 'ds',
  })

  c.x = d3.scaleLog().domain(util.epsilonExtent).range(c.x.range())
  c.xAxis = d3.axisBottom(c.x).tickFormat(d => {
    const tickText = d + ''
    // Only label ticks whose first character other than '0' and '.' is a 1
    // (e.g. 0.01, 1, 10); all other ticks get no text.
    return /^[0.]*1/.test(tickText) ? tickText : undefined
  })
  c.yAxis.tickFormat(d3.format('.0%'))

  d3.drawAxis(c)
  util.addAxisLabel(c, 'Higher Privacy →', '')
  util.ggPlotBg(c, false)

  // Secondary x-axis caption, placed in the first ('d') layer below the axis.
  c.layers[0].append('div')
    .st({fontSize: 12, color: '#555', width: 120*2, textAlign: 'center', lineHeight: '1.3em', verticalAlign: 'top'})
    .translate([c.width/2 - 120, c.height + 45])
    .html('in ε')

  const pointPosition = d => [c.x(d.epsilon), c.y(d.accuracy)]
  const line = d3.line().x(d => c.x(d.epsilon)).y(d => c.y(d.accuracy))

  const lineSel = c.svg.append('g').appendMany('path.accuracy-line', byLabel)
    .at({
      d: line,
      fill: 'none',
      stroke: '#000',
    })
    .on('mousemove', setActiveLabel)

  const circleSel = c.svg.append('g')
    .appendMany('g.accuracy-circle', data)
    .translate(pointPosition)
    .on('mousemove', setActiveLabel)

  circleSel.append('circle')
    .at({r: 7, stroke: '#fff'})

  circleSel.append('text')
    .text(d => d.label)
    .at({textAnchor: 'middle', fontSize: 10, fill: '#fff', dy: '.33em'})

  setActiveLabel(state)

  /**
   * Highlights the line and circles of one digit label and records it in
   * `state.label`. Used both directly and as a mousemove handler, where the
   * argument is the hovered element's datum.
   */
  function setActiveLabel({label}){
    // Labels may be numbers or numeric strings depending on how the datum was
    // built, so compare them as numbers.
    const isActive = d => +d.label === +label

    lineSel
      .classed('active', false)
      .filter(isActive)
      .classed('active', true)
      .raise()

    circleSel
      .classed('active', false)
      .filter(isActive)
      .classed('active', true)
      .raise()

    state.label = label
  }

  /**
   * Reloads the data for the current `state.dataset_size` and transitions the
   * existing lines and circles to their new positions.
   */
  async function updateDatasetSize(){
    const requestId = ++latestRequestId
    const newData = await loadData()
    // A newer request was started while this one was loading; applying this
    // result now could overwrite the newer selection with older data.
    if (requestId !== latestRequestId) return

    data = newData.data
    byLabel = newData.byLabel

    // NOTE(review): data is rebound by index with no enter/exit handling, so
    // this assumes every dataset size yields the same number of entries in the
    // same order — confirm.
    lineSel.data(byLabel)
      .transition()
      .at({d: line})

    circleSel.data(data)
      .transition()
      .translate(pointPosition)

    c.svg.select('text.annotation').remove()
  }

  /**
   * Recomputes every point's accuracy from its raw rows using the current
   * `state.threshold`, then redraws lines, circles and the y-axis label.
   */
  function updateThreshold(){
    // Coerce once so the comparison below is numeric even if the threshold
    // was stored as a string.
    const threshold = +state.threshold

    data.forEach(d => {
      d.accuracy = d3.mean(d.digitClass, e => e[d.key] > threshold)
    })

    lineSel.at({d: line})
    circleSel.translate(pointPosition)

    c.svg.select('.y .axis-label').text(`Test Points With More Than ${d3.format('.2%')(threshold)} Confidence In Label`)

    c.svg.select('text.annotation').remove()
  }

  updateThreshold()

  return {c, updateDatasetSize, updateThreshold}
}
/**
 * Entry point: adds the chart title, builds the chart, then adds the controls
 * (dataset-size buttons and the confidence-threshold slider) and the initial
 * text annotation.
 */
async function init(){
  sel.append('div.chart-title').text('High accuracy models can still perform poorly on some digit classes')

  const chart = await initChart()

  const buttonRowSel = sel.append('div.button-row')
    .st({height: 50})

  // One button per available training-set size; clicking reloads the chart.
  const buttonSel = buttonRowSel.append('div')
    .st({width: 500})
    .append('span.chart-title').text('Training points')
    .parent()
    .append('div').st({display: 'inline-block', width: 300, marginLeft: 10})
    .append('div.digit-button-container.dataset_size')
    .appendMany('div.button', [2000, 3750, 7500, 15000, 30000, 60000])
    .text(d3.format(','))
    .classed('active', d => d === state.dataset_size)
    .on('click', d => {
      buttonSel.classed('active', e => e === d)
      state.dataset_size = d
      chart.updateDatasetSize()
    })

  // Native range input; addSliders() below removes it again and draws an SVG
  // slider in its place.
  buttonRowSel.append('div.conf-slider')
    .append('span.chart-title').text('Confidence threshold')
    .parent()
    .append('input.slider-native')
    .at({
      type: 'range',
      min: .0001,
      max: .9999,
      step: .0001,
      value: state.threshold,
    })
    .on('input', function(){
      // input.value is a string; keep state.threshold numeric.
      state.threshold = +this.value
      chart.updateThreshold()
    })

  /**
   * Replaces the native range input with an SVG slider bound to
   * `state.threshold` (click or drag to change the value).
   */
  function addSliders(){
    const width = 140
    const height = 30
    const color = '#000'

    const sliders = [
      {key: 'threshold', label: 'Confidence threshold', r: [.0001, .9999]},
    ]
    sliders.forEach(d => {
      d.value = state[d.key]
      d.xScale = d3.scaleLinear().range([0, width]).domain(d.r).clamp(1)
    })

    // Scope the lookups to this chart's own button row rather than the whole
    // document, so controls of other charts on the page are left alone.
    buttonRowSel.select('.conf-slider .slider-container').remove()
    buttonRowSel.select('.slider-native').remove()

    // The slider containers are siblings of .conf-slider inside the button row.
    const svgSel = buttonRowSel
      .appendMany('div.slider-container', sliders)
      .append('svg').at({width, height})
      .append('g').translate([10, 25])

    const sliderSel = svgSel
      .on('click', function(d){
        d.value = d.xScale.invert(d3.mouse(this)[0])
        renderSliders(d)
      })
      .classed('slider', true)
      .st({cursor: 'pointer'})

    // Value readout above the handle; its x is set in renderSliders.
    const textSel = sliderSel.append('text.annotation')
      .at({y: -15, fontWeight: 300, textAnchor: 'middle', x: 180/2})

    // Transparent rect so clicks anywhere in the slider area are captured.
    sliderSel.append('rect')
      .at({width, height, y: -height/2, fill: 'rgba(0,0,0,0)'})

    // Thin full-width track.
    sliderSel.append('path').at({
      d: `M 0 -.5 H ${width}`,
      stroke: color,
      strokeWidth: 1
    })

    // Thick track segment from the left edge up to the handle.
    const leftPathSel = sliderSel.append('path').at({
      d: `M 0 -.5 H ${width}`,
      stroke: color,
      strokeWidth: 3
    })

    const drag = d3.drag()
      .on('drag', function(d){
        const x = d3.mouse(this)[0]
        d.value = d.xScale.invert(x)
        renderSliders(d)
      })

    const circleSel = sliderSel.append('circle').call(drag)
      .at({r: 7, stroke: '#000'})

    /**
     * Writes the changed slider's value (if any) into `state`, repositions the
     * handle, filled track and readout, and refreshes the chart.
     */
    function renderSliders(d){
      if (d) state[d.key] = d.value

      circleSel.at({cx: d => d.xScale(d.value)})
      leftPathSel.at({d: d => `M 0 -.5 H ${d.xScale(d.value)}`})
      textSel
        .at({x: d => d.xScale(d.value)})
        .text(d => d3.format('.2%')(d.value))
      chart.updateThreshold()
    }
    renderSliders()
  }
  addSliders()

  // Added after the first updateThreshold() calls above, each of which removes
  // any existing chart annotation; it therefore stays until the next update.
  chart.c.svg.append('text.annotation')
    .translate([505, 212])
    .tspans(d3.wordwrap(`8s are correctly predicted with high confidence much more rarely than other digits`, 25), 12)
    .at({textAnchor: 'end'})
}
// Build the chart and its controls as soon as the script runs. init() is async
// and the returned promise is not awaited here.
init()