/* Copyright 2021 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/


/**
 * Builds one interactive sentence-pair widget: two editable sentences with a
 * `_` blank, option buttons (token count, chart type, model), and a scatter
 * plot of the per-vocabulary-token scores (logits, per the tooltip text)
 * returned by `post` for each sentence.
 *
 * Renders into the element with class `pair.class` and, when `pair.alts` is
 * set, a row of alternative sentence pairs into `.{pair.class}-alts`.
 * Returns early without drawing if the `pair.class` element does not exist.
 *
 * @param {Object} pair - widget config. Fields read here: `class`,
 *   `ariaLabel`, `s0`/`s1` (sentences containing `[MASK]`), `count`, `type`
 *   ('Likelihoods' | 'Differences'), `model` ('BERT' | 'Zari'), `pairStr`
 *   (JSON snapshot restored by Reset), `alts` and `annotations`.
 *   `s0`, `s1`, `label0` and `label1` are written back on every update.
 */
window.initPair = function(pair){
  var isMobile = window.innerWidth <= 820

  // Incremented by every updateChart call so that responses from superseded
  // requests are discarded instead of being drawn over newer results.
  var updateCount = 0

  var sel = d3.select('.' + pair.class).html('')
    .at({role: 'graphics-document', 'aria-label': pair.ariaLabel})
    .on('keydown', function(){
      sel.classed('changed', 1)
      // Only Enter re-runs the model; other keys just mark the pair as edited.
      if (d3.event.keyCode != 13) return
      d3.event.preventDefault()

      pair.str0 = ''
      pair.str1 = ''

      updateChart()
    })

  if (!sel.node()) return

  var optionSel = sel.append('div.options')

  // The textareas show the blank as `_`; readSentence converts it back.
  var inputRow = optionSel.append('div.flex-row.flex-row-textarea')
  var input1Sel = inputRow.append('textarea.input-1')
    .st({color: util.colors[1]}).at({cols: 30})
  input1Sel.node().value = toDisplayStr(pair.s1)

  var input0Sel = inputRow.append('textarea.input-0')
    .st({color: util.colors[0]}).at({cols: 30})
  input0Sel.node().value = toDisplayStr(pair.s0)

  if (isMobile){
    sel.selectAll('textarea').on('change', updateChart)
  }

  var countSel = optionSel.append('div')
    .append('b').text('Number of Tokens')
    .append('info').text('ⓘ').call(addLockedTooltip)
    .datum('The scales are set using the top N tokens for each sentence.\n\n' +
      '"Likelihoods" will show more than N tokens if a top completion for one sentence is unlikely for the other sentence.')
    .parent().parent()
    .append('div.flex-row')
    .appendMany('div.button', [30, 200, 1000, 5000, 99999])
    .text(d => d > 5000 ? 'All' : d)
    .st({textAlign: 'center'})
    .on('click', d => {
      pair.count = d
      updateChart()
    })

  var typeSel = optionSel.append('div')
    .append('b').text('Chart Type')
    .append('info').text('ⓘ').call(addLockedTooltip)
    .datum('"Likelihoods" shows the logits from both models plotted directly with a shared linear scale.\n\n' +
      'To better contrast the outputs, "Differences" shows logitA - logitB on the y-axis and mean(logitA, logitB) on the x-axis with separate linear scales.')
    .parent().parent()
    .append('div.flex-row')
    .appendMany('div.button', ['Likelihoods', 'Differences'])
    .text(d => d)
    .st({textAlign: 'center'})
    .on('click', d => {
      pair.type = d
      updateChart()
    })

  // The whole model picker is hidden for pairs whose initial model is BERT.
  var modelSel = optionSel.append('div')
    .st({display: pair.model == 'BERT' ? 'none' : ''})
    .append('b').text('Model')
    .parent()
    .append('div.flex-row')
    .appendMany('div.button', ['BERT', 'Zari'])
    .text(d => d)
    .st({textAlign: 'center'})
    .on('click', d => {
      pair.model = d
      updateChart()
    })

  // TODO add loading spinner
  var updateSel = optionSel
    .append('div.flex-row')
    .append('div.button.update').on('click', updateChart)
    .text('Update')
    .st({display: isMobile ? 'none' : ''})

  var warningSel = optionSel.append('div.warning')
    .text('⚠️Some of the text this model was trained on includes harmful stereotypes. This is a tool to uncover these associations—not an endorsement of them.')

  var resetSel = optionSel.append('div.reset')
    .html(' Reset')
    .on('click', () => {
      // Rebinds the closed-over `pair`, so every handler sees the restored config.
      pair = JSON.parse(pair.pairStr)
      pair.pairStr = JSON.stringify(pair)

      input0Sel.node().value = toDisplayStr(pair.s0)
      input1Sel.node().value = toDisplayStr(pair.s1)

      updateChart(true)
    })

  if (pair.alts){
    d3.select('.' + pair.class + '-alts').html('')
      .classed('alt-block', 1).st({display: 'block'})
      .appendMany('span.p-button-link', pair.alts)
      .html(d => d.str)
      .on('click', d => {
        input0Sel.node().value = toDisplayStr(d.s0)
        input1Sel.node().value = toDisplayStr(d.s1)

        updateChart()
      })
  }

  var margin = {bottom: 50, left: 25, top: 5, right: 20}
  var graphSel = sel.append('div.graph')
  var totalWidth = graphSel.node().offsetWidth
  var width = totalWidth - margin.left - margin.right

  var c = d3.conventions({
    sel: graphSel.append('div').st({marginTop: isMobile ? 20 : -5}),
    width,
    height: width,
    margin,
    layers: 'sdds',
  })

  // Evenly spaced background grid lines, vertical then horizontal.
  var nTicks = 4
  var tickScale = d3.scaleLinear().range([0, c.width])
  c.svg.appendMany('path.bg-tick', d3.range(nTicks + 1))
    .at({d: d => `M ${.5 + Math.round(tickScale(d/nTicks))} 0 V ${c.height}`})
  c.svg.appendMany('path.bg-tick', d3.range(nTicks + 1))
    .at({d: d => `M 0 ${.5 + Math.round(tickScale(d/nTicks))} H ${c.width}`})

  var annotationSel = c.layers[1].appendMany('div.annotations', pair.annotations || [])
    .translate(d => d.pos)
    .html(d => d.str)
    .st({color: d => d.color, width: 250, position: 'absolute'})

  var scatter = window.initScatter(c)

  updateChart(true)

  /**
   * Re-reads both sentences, queries the selected model for each, and redraws.
   *
   * @param {boolean} [isFirst] - truthy for the initial and post-reset render:
   *   the annotations are shown and the warning/reset controls are faded to
   *   opacity 0. Event handlers pass their d3 datum here (undefined).
   */
  async function updateChart(isFirst){
    var updateId = ++updateCount

    sel.classed('changed', 0)
    warningSel.st({opacity: isFirst ? 0 : 1})
    resetSel.st({opacity: isFirst ? 0 : 1})
    annotationSel.st({opacity: isFirst ? 1 : 0})

    countSel.classed('active', d => d == pair.count)
    typeSel.classed('active', d => d == pair.type)
    modelSel.classed('active', d => d == pair.model)

    var modelPath = pair.model == 'Zari' ? 'embed_zari_cda' : 'embed'

    pair.s0 = readSentence(input0Sel)
    pair.s1 = readSentence(input1Sel)

    updateSel.classed('loading', 1)
    var responses
    try {
      // The two sentences are independent, so request them in parallel.
      responses = await Promise.all([
        post(modelPath, {sentence: pair.s0}),
        post(modelPath, {sentence: pair.s1}),
      ])
    } finally {
      // If a newer call has started, it owns the loading state.
      if (updateId == updateCount) updateSel.classed('loading', 0)
    }
    if (updateId != updateCount) return

    var vals0 = responses[0]
    var vals1 = responses[1]

    // One record per vocab index: score under each sentence, their difference
    // and mean, and (below) the rank of each score.
    var allTokens = vals0.map((v0, i) => {
      return {word: tokenizer.vocab[i], v0, i, v1: vals1[i]}
    })
    allTokens.forEach(d => {
      d.dif = d.v0 - d.v1
      d.meanV = (d.v0 + d.v1) / 2
      d.isVisible = false
    })

    _.sortBy(allTokens, d => -d.v1).forEach((d, i) => d.v1i = i)
    _.sortBy(allTokens, d => -d.v0).forEach((d, i) => d.v0i = i)

    // Ranks are 0-based, so `<` keeps exactly the top N tokens per sentence.
    var topTokens = allTokens.filter(d => d.v0i < pair.count || d.v1i < pair.count)

    var logitExtent = d3.extent(topTokens.map(d => d.v0).concat(topTokens.map(d => d.v1)))

    // Visible in "Likelihoods": both scores at or above the scale's lower bound.
    var tokens = allTokens
      .filter(d => logitExtent[0] <= d.v0 && logitExtent[0] <= d.v1)

    // Pad the shared scale by 0.2% of its span on each side.
    var mag = logitExtent[1] - logitExtent[0]
    logitExtent = [logitExtent[0] - mag*.002, logitExtent[1] + mag*.002]

    if (pair.type == 'Differences') tokens = _.sortBy(allTokens, d => -d.meanV).slice(0, pair.count)

    tokens.forEach(d => {
      d.isVisible = true
    })

    var maxDif = d3.max(d3.extent(tokens, d => d.dif).map(Math.abs))
    var color = palette(-maxDif*.8, maxDif*.8)

    updateSentenceLabels()

    if (pair.type == 'Likelihoods'){
      drawXY()
    } else{
      drawRotated()
    }

    sel.classed('is-xy', pair.type == 'Likelihoods')
    sel.classed('is-rotate', pair.type != 'Likelihoods')

    // Sentence-two score on x, sentence-one score on y, one shared scale.
    function drawXY(){
      c.x.domain(logitExtent)
      c.y.domain(logitExtent)

      d3.drawAxis(c)

      // Dot size shrinks as more tokens are shown.
      var s = {30: 4, 200: 3, 1000: 3}[pair.count] || 2
      var scatterData = allTokens.map(d => {
        var x = c.x(d.v0)
        var y = c.y(d.v1)
        var fill = color(d.dif)
        var dif = d.dif
        var word = d.word
        var show = ''
        var isVisible = d.isVisible

        return {x, y, s, dif, fill, word, show, isVisible}
      })

      // Within each 10px band of y, flag the lowest-dif token ('uf') and the
      // highest-dif token ('lr') among the 1000 candidates from each end.
      var textCandidates = _.sortBy(scatterData.filter(d => d.isVisible), d => d.dif)
      d3.nestBy(textCandidates.slice(0, 1000), d => Math.round(d.y/10))
        .forEach(d => d[0].show = 'uf')
      d3.nestBy(textCandidates.reverse().slice(0, 1000), d => Math.round(d.y/10))
        .forEach(d => d[0].show = 'lr')

      scatter.draw(c, scatterData, true)

      c.svg.selectAppend('text.x-axis-label.xy-only')
        .translate([c.width/2, c.height + 24])
        .text(pair.label0 ? ' __ likelihood, ' + pair.label0 + ' sentence →' : '__ likelihood, sentence two →')
        .st({fill: util.colors[0]})
        .at({textAnchor: 'middle'})

      c.svg.selectAppend('g.y-axis-label.xy-only')
        .translate([c.width + 20, c.height/2])
        .selectAppend('text')
        .text(pair.label1 ? ' __ likelihood, ' + pair.label1 + ' sentence →' : '__ likelihood, sentence one →')
        .st({fill: util.colors[1]})
        .at({textAnchor: 'middle', transform: 'rotate(-90)'})
    }

    // Mean score on x, score difference (v0 - v1) on y, separate scales.
    function drawRotated(){
      c.x.domain(d3.extent(tokens, d => d.meanV))
      c.y.domain([maxDif, -maxDif])

      d3.drawAxis(c)

      var scatterData = allTokens.map(d => {
        var x = c.x(d.meanV)
        var y = c.y(d.dif)
        var fill = color(d.dif)
        var word = d.word
        var show = ''
        var isVisible = d.isVisible

        return {x, y, s: 2, fill, word, show, isVisible}
      })

      scatterData.forEach(d => {
        d.dx = d.x - c.width/2
        d.dy = d.y - c.height/2
      })

      // Among the 5000 visible points farthest from the chart center, flag the
      // farthest one in each angular bucket (round(12 * atan2(dx, dy))), with
      // the quadrant it falls in.
      var textCandidates = _.sortBy(scatterData, d => -d.dx*d.dx - d.dy*d.dy)
        .filter(d => d.isVisible)
        .slice(0, 5000)
      d3.nestBy(textCandidates, d => Math.round(12*Math.atan2(d.dx, d.dy)))
        .map(d => d[0])
        .forEach(d => d.show = (d.dy < 0 ? 'u' : 'l') + (d.dx < 0 ? 'l' : 'r'))

      scatter.draw(c, scatterData, false)

      c.svg.selectAppend('text.rotate-only.x-axis-label')
        .translate([c.width/2, c.height + 24])
        .text('__ likelihood, both sentences →')
        .at({textAnchor: 'middle'})
        .st({fill: '#000'})

      // Both y-axis captions live in one group that is rebuilt on every draw.
      c.svg.selectAll('g.rotate-only.sent-1').remove()
      var sentLabelSel = c.svg.append('g.rotate-only.sent-1')
        .translate([c.width + 20, c.height/2])

      sentLabelSel.append('text')
        .text(`Higher likelihood, ${pair.label1 ? pair.label1 + ' sentence ' : 'sentence one'} →`)
        .at({textAnchor: 'start', transform: 'rotate(-90)', x: 20})
        .st({fill: util.colors[1]})

      sentLabelSel.append('text')
        .text(`← Higher likelihood, ${pair.label0 ? pair.label0 + ' sentence ' : 'sentence two'}`)
        .at({textAnchor: 'end', transform: 'rotate(-90)', x: -20})
        .st({fill: util.colors[0]})
    }
  }

  /**
   * Sets `pair.label0` / `pair.label1` to the part of each sentence that
   * differs from the other sentence (the tokens left after removing the shared
   * token prefix and suffix), in the sentence's original casing. Both labels
   * are cleared when either one is empty or longer than 15 characters.
   */
  function updateSentenceLabels(){
    var t0 = tokenizer.tokenize(pair.s0)
    var t1 = tokenizer.tokenize(pair.s1)

    // After this loop the first i tokens of both sentences match.
    var i = 0
    while (t0[i] == t1[i] && i < t0.length) i++

    // After this loop the last j - 1 tokens match (j is capped at t0.length).
    var j = 1
    while (t0[t0.length - j] == t1[t1.length - j] && j < t0.length) j++

    pair.label0 = tokens2origStr(t0, pair.s0)
    pair.label1 = tokens2origStr(t1, pair.s1)

    function tokens2origStr(t, s){
      // Use a positive end index: `-j + 1` would be 0 when no suffix is
      // shared (j == 1) and wrongly produce an empty slice.
      var tokenStr = tokenizer.decode(t.slice(i, t.length - j + 1)).trim()
      var lowerStr = s.toLowerCase()
      var startI = lowerStr.indexOf(tokenStr)
      // The decoded text may not appear verbatim; yield '' so labels are cleared.
      if (startI == -1) return ''
      return s.slice(startI, startI + tokenStr.length)
    }

    if (
      !pair.label0.length || !pair.label1.length ||
      pair.label0.length > 15 || pair.label1.length > 15){
      pair.label0 = ''
      pair.label1 = ''
    }
  }

  // Textarea display form: the first `[MASK]` is shown as `_`.
  function toDisplayStr(str){
    return str.replace('[MASK]', '_')
  }

  // Model input form: the first `_` in the textarea becomes `[MASK]`.
  function readSentence(textareaSel){
    return textareaSel.node().value.replace('_', '[MASK]')
  }
}

if (window.init) init()