/*
 * makenmtviz / plotsjs_bertviz.js
 * Gabriela Nicole Gonzalez Saez
 * Add files
 * 056bbdc
 * raw
 * history blame
 * No virus
 * 12.3 kB
 */
async () => {
// Expose testFn() on globalThis so that an inline HTML onclick attribute can reach it.
globalThis.testFn = () => {
  // Smoke test: overwrite the contents of the #demo element with a fixed greeting.
  const demoElement = document.getElementById('demo');
  demoElement.innerHTML = "Hello-bertviz?";
};
// await import * as mod from "/my-module.js";
// d3 is loaded from its ES-module build, so the resolved namespace object is used as `d3` directly.
const d3 = await import("https://cdn.jsdelivr.net/npm/d3@5/+esm");
// jquery.min.js is a UMD bundle rather than an ES module: importing it only executes it for its
// side effect of defining the jQuery global, and the namespace object that import() resolves to
// carries no exports. Publishing that namespace as globalThis.$ would replace the working `$`
// with a non-callable object, so read the real jQuery function back from the global instead.
await import("https://cdn.jsdelivr.net/npm/jquery@3.7.1/dist/jquery.min.js");
const $ = globalThis.jQuery;
globalThis.$ = $;
// const $ = await import("https://cdn.jsdelivr.net/npm/jquery@2/+esm");
// import $ from "jquery";
// import * as d3 from "https://cdn.jsdelivr.net/npm/d3@7/+esm";
// await import("https://cdn.jsdelivr.net/npm/jquery@2/+esm");
// export for others scripts to use
// window.$ = window.jQuery = jQuery;
// const d3 = await import("https://cdnjs.cloudflare.com/ajax/libs/d3/5.7.0/d3.min");
// const $ = await import("https://cdnjs.cloudflare.com/ajax/libs/jquery/2.0.0/jquery.min");
// Demo hook: appends an svg with a 50x50 square to #viz; the square is red while hovered, black otherwise.
globalThis.d3Fn = () => {
  const IDLE_FILL = 'black';
  const HOVER_FILL = 'red';

  // Builds a listener that recolours whichever element the event fired on (`this`).
  const paintWith = (colour) => function () {
    d3.select(this).attr('fill', colour);
  };

  const square = d3.select('#viz').append('svg').append('rect');
  square
    .attr('width', 50)
    .attr('height', 50)
    .attr('fill', IDLE_FILL);
  square.on('mouseover', paintWith(HOVER_FILL));
  square.on('mouseout', paintWith(IDLE_FILL));
};
//
// Debug pass-through: logs `val`, then hands both arguments back unchanged as a two-element array.
globalThis.testFn_out = (val, model) => {
  console.log(val);
  const passthrough = [val, model];
  return passthrough;
};
// Entry point for a JSON payload: logs the payload and the result of looking up #viz,
// renders the payload with attViz(), and always returns the fixed pair ['string', {}].
globalThis.testFn_out_json = (data) => {
  console.log(data);
  // jQuery is taken from the page global rather than from any `$` in the enclosing scope.
  const vizElement = jQuery('#viz');
  console.log(vizElement);
  attViz(data);
  return ['string', {}];
};
function attViz(PYTHON_PARAMS) {
// jQuery is read from the page global; the nested helpers below reach `$` through this closure.
const $ = jQuery;
// Visualization parameters. PYTHON_PARAMS is this function's argument; the name is kept from an
// earlier template-marker scheme (per the original HACK note).
const params = PYTHON_PARAMS;

// Layout constants used by the rendering helpers.
const TEXT_SIZE = 15;
const BOXWIDTH = 110;
const BOXHEIGHT = 22.5;
const MATRIX_WIDTH = 115;
const CHECKBOX_SIZE = 20;
const TEXT_TOP = 30;

console.log("d3 version in ffuntions", d3.version);

// Build the per-head colour scale, trying the d3.scaleOrdinal API first and falling back to
// d3.scale.category10() when that call throws.
const pickHeadColors = () => {
  try {
    return d3.scaleOrdinal(d3.schemeCategory10);
  } catch (legacyApiError) {
    console.log('Older d3 version');
    return d3.scale.category10();
  }
};
const headColors = pickHeadColors();

// Mutable view state shared by every helper; filled in by initialize().
const config = {};

initialize();
renderVis();
/**
 * Populates the shared `config` object from `params` and wires up the layer and filter
 * dropdowns found under the root div.
 *
 * Reads from `params`: 'attention', 'default_filter', 'root_div_id', 'include_layers',
 * and the optional 'heads' and 'layer'.
 *
 * The dropdown elements may outlive a single attViz() call, so handlers are bound under the
 * `.attViz` event namespace and any previously bound ones are removed first. Without this,
 * every call would stack another handler, and each stale handler would keep its old data alive
 * and re-render on every change.
 */
function initialize() {
  console.log("init");
  config.attention = params['attention'];
  config.filter = params['default_filter'];
  config.rootDivId = params['root_div_id'];

  const attnByLayer = config.attention[config.filter]['attn'];
  config.nLayers = attnByLayer.length;
  config.nHeads = attnByLayer[0].length;
  config.layers = params['include_layers'];

  // Head visibility flags: only the requested heads when 'heads' is given, otherwise all heads.
  if (params['heads']) {
    config.headVis = new Array(config.nHeads).fill(false);
    params['heads'].forEach((x) => {
      config.headVis[x] = true;
    });
  } else {
    config.headVis = new Array(config.nHeads).fill(true);
  }
  config.initialTextLength = config.attention[config.filter].right_text.length;

  // Position of the requested layer within the included layers. A layer that is not in the
  // list would yield -1 and leave both the layer and its attention undefined, so fall back to
  // the first included layer instead.
  let layerSeq = 0;
  if (params['layer'] != null) {
    layerSeq = config.layers.findIndex((layer) => params['layer'] === layer);
    if (layerSeq === -1) {
      console.warn("attViz: requested layer is not in include_layers; using the first layer", params['layer']);
      layerSeq = 0;
    }
  }
  config.layer_seq = layerSeq;
  config.layer = config.layers[config.layer_seq];

  const layerEl = $('#' + config.rootDivId + ' #layer');
  console.log(layerEl);
  // Remove the handler left by an earlier attViz() call before the change event is triggered.
  layerEl.off('change.attViz');
  layerEl.empty();
  for (const layer of config.layers) {
    layerEl.append($("<option />").val(layer).text(layer));
  }
  layerEl.val(config.layer).trigger('change');
  layerEl.on('change.attViz', function (e) {
    // The dropdown value is a string; coerce it to a number before matching it against config.layers.
    config.layer = +e.currentTarget.value;
    config.layer_seq = config.layers.findIndex((layer) => config.layer === layer);
    renderVis();
  });

  $('#' + config.rootDivId + ' #filter')
    .off('change.attViz')
    .on('change.attViz', function (e) {
      config.filter = e.currentTarget.value;
      renderVis();
    });
}
/**
 * Redraws the whole visualization for the currently selected filter and layer:
 * clears the #vis container, creates a fresh svg sized to the longer token column,
 * then draws both token columns, the attention lines and the head checkboxes.
 */
function renderVis() {
  // Data for the active filter, and the attention slice for the active layer.
  const { left_text: leftText, right_text: rightText, attn } = config.attention[config.filter];
  const layerAttention = attn[config.layer_seq];

  // Start from an empty container.
  const visSelector = '#' + config.rootDivId + ' #vis';
  $(visSelector).empty();

  // One row per token of the longer column, plus the band reserved above the first row.
  const rowCount = Math.max(leftText.length, rightText.length);
  const height = rowCount * BOXHEIGHT + TEXT_TOP;
  const svg = d3.select(visSelector)
    .append('svg')
    .attr("width", "100%")
    .attr("height", height + "px");

  // Token columns: left at x = 0, right just past the line area.
  const rightColumnX = MATRIX_WIDTH + BOXWIDTH;
  renderText(svg, leftText, true, layerAttention, 0);
  renderText(svg, rightText, false, layerAttention, rightColumnX);

  // Lines between the columns.
  renderAttention(svg, layerAttention);

  // Coloured squares along the top edge, one per head.
  drawCheckboxes(0, svg, layerAttention);
}
/**
 * Draws one column of tokens plus the (initially invisible) per-head highlight boxes that are
 * superimposed on it, and installs the hover behaviour for that column.
 *
 * @param svg       d3 selection of the svg to draw into.
 * @param text      Array of token strings, one row each.
 * @param isLeft    true for the left column (group id "left"), false for the right (id "right").
 * @param attention Attention for the current layer; bound one element per "head-index" group.
 *                  Presumably indexed [head][left token][right token], matching the attribute
 *                  names used in renderAttention — confirm against the caller.
 * @param leftPos   x coordinate at which this column starts.
 */
function renderText(svg, text, isLeft, attention, leftPos) {
const textContainer = svg.append("svg:g")
.attr("id", isLeft ? "left" : "right");
// Add attention highlights superimposed over words:
// one <g> per element of `attention`, and inside it one <rect> per row of that element.
textContainer.append("g")
.classed("attentionBoxes", true)
.selectAll("g")
.data(attention)
.enter()
.append("g")
.attr("head-index", (d, i) => i)
.selectAll("rect")
.data(d => isLeft ? d : transpose(d)) // if right text, transpose attention to get right-to-left weights
.enter()
.append("rect")
.attr("x", function () {
// Heads share the box width side by side; offset by the number of visible heads before this one.
var headIndex = +this.parentNode.getAttribute("head-index");
return leftPos + boxOffsets(headIndex);
})
// Every box starts at the same y; the mouseover handler below assigns the per-row y.
.attr("y", (+1) * BOXHEIGHT)
.attr("width", BOXWIDTH / activeHeads())
.attr("height", BOXHEIGHT)
.attr("fill", function () {
return headColors(+this.parentNode.getAttribute("head-index"))
})
// Invisible until a token in the opposite column is hovered.
.style("opacity", 0.0);
// One <g> per token, holding its hover background and its label.
const tokenContainer = textContainer.append("g").selectAll("g")
.data(text)
.enter()
.append("g");
// Add gray background that appears when hovering over text
tokenContainer.append("rect")
.classed("background", true)
.style("opacity", 0.0)
.attr("fill", "lightgray")
.attr("x", leftPos)
.attr("y", (d, i) => TEXT_TOP + i * BOXHEIGHT)
.attr("width", BOXWIDTH)
.attr("height", BOXHEIGHT);
// Add token text
const textEl = tokenContainer.append("text")
.text(d => d)
.attr("font-size", TEXT_SIZE + "px")
.style("cursor", "default")
.style("-webkit-user-select", "none")
.attr("x", leftPos)
.attr("y", (d, i) => TEXT_TOP + i * BOXHEIGHT);
// Left column text is right-aligned against the line area; right column text is left-aligned.
if (isLeft) {
textEl.style("text-anchor", "end")
.attr("dx", BOXWIDTH - 0.5 * TEXT_SIZE)
.attr("dy", TEXT_SIZE);
} else {
textEl.style("text-anchor", "start")
.attr("dx", +0.5 * TEXT_SIZE)
.attr("dy", TEXT_SIZE);
}
// Listener receives (datum, index); `index` is the row of the hovered token.
tokenContainer.on("mouseover", function (d, index) {
// Show gray background for moused-over token
textContainer.selectAll(".background")
.style("opacity", (d, i) => i === index ? 1.0 : 0.0)
// Reset visibility attribute for any previously highlighted attention arcs
svg.select("#attention")
.selectAll("line[visibility='visible']")
.attr("visibility", null)
// Hide group containing attention arcs
svg.select("#attention").attr("visibility", "hidden");
// Set to visible appropriate attention arcs to be highlighted
if (isLeft) {
svg.select("#attention").selectAll("line[left-token-index='" + index + "']").attr("visibility", "visible");
} else {
svg.select("#attention").selectAll("line[right-token-index='" + index + "']").attr("visibility", "visible");
}
// Update color boxes superimposed over tokens.
// The boxes updated are those of the OPPOSITE column; this `leftPos` shadows the parameter
// and holds that opposite column's x position.
const id = isLeft ? "right" : "left";
const leftPos = isLeft ? MATRIX_WIDTH + BOXWIDTH : 0;
svg.select("#" + id)
.selectAll(".attentionBoxes")
.selectAll("g")
.attr("head-index", (d, i) => i)
.selectAll("rect")
.attr("x", function () {
const headIndex = +this.parentNode.getAttribute("head-index");
return leftPos + boxOffsets(headIndex);
})
.attr("y", (d, i) => TEXT_TOP + i * BOXHEIGHT)
.attr("width", BOXWIDTH / activeHeads())
.attr("height", BOXHEIGHT)
.style("opacity", function (d) {
// Opacity is the bound row's entry for the hovered token, or 0 when the head is
// switched off or the rect has no bound row.
const headIndex = +this.parentNode.getAttribute("head-index");
if (config.headVis[headIndex])
if (d) {
return d[index];
} else {
return 0.0;
}
else
return 0.0;
});
});
// Leaving the whole column restores the default view.
textContainer.on("mouseleave", function () {
// Unhighlight selected token
d3.select(this).selectAll(".background")
.style("opacity", 0.0);
// Reset visibility attributes for previously selected lines
svg.select("#attention")
.selectAll("line[visibility='visible']")
.attr("visibility", null) ;
svg.select("#attention").attr("visibility", "visible");
// Reset highlights superimposed over tokens
svg.selectAll(".attentionBoxes")
.selectAll("g")
.selectAll("rect")
.style("opacity", 0.0);
});
}
/**
 * Rebuilds the group of attention lines between the two token columns.
 *
 * The data is bound in three nested levels: one group per element of `attention`
 * (tagged "head-index"), one sub-group per row (tagged "left-token-index"), and one
 * <line> per entry in the row (tagged "right-token-index"). Line opacity is applied
 * afterwards by updateAttention().
 *
 * @param svg       d3 selection of the svg to draw into.
 * @param attention Attention for the current layer.
 */
function renderAttention(svg, attention) {
  // Vertical centre of the token row at the given index.
  const rowCenterY = (row) => TEXT_TOP + row * BOXHEIGHT + (BOXHEIGHT / 2);

  // Drop whatever was drawn before.
  svg.select("#attention").remove();

  // Container for all attention lines.
  const container = svg.append("g").attr("id", "attention");

  // Level 1: one group per head.
  const headGroups = container
    .selectAll(".headAttention")
    .data(attention)
    .enter()
    .append("g")
    .classed("headAttention", true)
    .attr("head-index", (headRows, headIdx) => headIdx);

  // Level 2: one group per left token.
  const tokenGroups = headGroups
    .selectAll(".tokenAttention")
    .data((headRows) => headRows)
    .enter()
    .append("g")
    .classed("tokenAttention", true)
    .attr("left-token-index", (row, leftIdx) => leftIdx);

  // Level 3: one line per (left token, right token) pair.
  const lines = tokenGroups
    .selectAll("line")
    .data((row) => row)
    .enter()
    .append("line");

  lines
    .attr("x1", BOXWIDTH)
    .attr("y1", function () {
      return rowCenterY(+this.parentNode.getAttribute("left-token-index"));
    })
    .attr("x2", BOXWIDTH + MATRIX_WIDTH)
    .attr("y2", (weight, rightIdx) => rowCenterY(rightIdx))
    .attr("stroke-width", 2)
    .attr("stroke", function () {
      // The head index lives two levels up, on the head group.
      return headColors(+this.parentNode.parentNode.getAttribute("head-index"));
    })
    .attr("left-token-index", function () {
      // Copied onto the line so the hover handlers can select lines by attribute.
      return +this.parentNode.getAttribute("left-token-index");
    })
    .attr("right-token-index", (weight, rightIdx) => rightIdx);

  updateAttention(svg);
}
/**
 * Applies the stroke opacity of every attention line according to head visibility.
 *
 * A line whose head is visible gets its bound weight divided by the number of active heads;
 * a line whose head is hidden gets opacity 0.
 *
 * @param svg d3 selection of the svg that contains the "#attention" group.
 */
function updateAttention(svg) {
  // The number of active heads cannot change while the lines are being updated, so count it
  // once here instead of recounting it for every single line.
  const activeCount = activeHeads();
  svg.select("#attention")
    .selectAll("line")
    .attr("stroke-opacity", function (d) {
      // The head index is stored two levels up, on the per-head group.
      const headIndex = +this.parentNode.parentNode.getAttribute("head-index");
      return config.headVis[headIndex] ? d / activeCount : 0.0;
    });
}
/**
 * Horizontal offset of head `i`'s highlight box inside a token cell.
 * Visible heads split BOXWIDTH evenly, so the offset is one slot width for every
 * visible head whose index is lower than `i`.
 *
 * @param i Head index.
 * @returns Offset to add to the column's x position.
 */
function boxOffsets(i) {
  const visibleBefore = config.headVis
    .filter((isVisible, headIdx) => isVisible && headIdx < i)
    .length;
  const slotWidth = BOXWIDTH / activeHeads();
  return visibleBefore * slotWidth;
}
/**
 * Counts the heads whose visibility flag in config.headVis is truthy.
 *
 * @returns Number of currently visible heads.
 */
function activeHeads() {
  let visibleCount = 0;
  for (const isVisible of config.headVis) {
    if (isVisible) {
      visibleCount += 1;
    }
  }
  return visibleCount;
}
/**
 * Draws one coloured square per head along the top of the svg and makes the squares act as
 * head toggles.
 *
 * - click: flips the head's visibility, except that the last remaining visible head cannot be
 *   switched off.
 * - dblclick: shows only that head; on the only visible head it shows all heads again.
 *
 * @param top y coordinate of the row of squares.
 * @param svg d3 selection of the svg to draw into.
 */
function drawCheckboxes(top, svg) {
  const group = svg.append("g");
  const boxes = group.selectAll("rect")
    .data(config.headVis)
    .enter()
    .append("rect")
    .attr("fill", (flag, headIdx) => headColors(headIdx))
    .attr("x", (flag, headIdx) => headIdx * CHECKBOX_SIZE)
    .attr("y", top)
    .attr("width", CHECKBOX_SIZE)
    .attr("height", CHECKBOX_SIZE);

  // Re-binds the current flags and paints hidden heads in a lightened version of their colour.
  function updateCheckboxes() {
    group.selectAll("rect")
      .data(config.headVis)
      .attr("fill", (isVisible, headIdx) => {
        const base = headColors(headIdx);
        return isVisible ? base : lighten(base);
      });
  }

  // Shared tail of both listeners: repaint the squares and re-apply the line opacities.
  const refreshAfterChange = () => {
    updateCheckboxes();
    updateAttention(svg);
  };

  // True when head `headIdx` is visible and no other head is.
  const isOnlyActive = (headIdx) => config.headVis[headIdx] && activeHeads() === 1;

  updateCheckboxes();

  boxes.on("click", function (flag, headIdx) {
    if (isOnlyActive(headIdx)) {
      return;
    }
    config.headVis[headIdx] = !config.headVis[headIdx];
    refreshAfterChange();
  });

  boxes.on("dblclick", function (flag, headIdx) {
    if (isOnlyActive(headIdx)) {
      // Double click on the only active head resets to all heads.
      config.headVis = new Array(config.nHeads).fill(true);
    } else {
      config.headVis = new Array(config.nHeads).fill(false);
      config.headVis[headIdx] = true;
    }
    refreshAfterChange();
  });
}
/**
 * Returns a paler variant of `color`, used for the squares of hidden heads.
 * The colour is converted with d3.hsl(); lightness is raised by 60% of its distance to 1
 * and saturation is lowered by the same amount.
 *
 * @param color Any colour value accepted by d3.hsl().
 * @returns The adjusted colour object produced by d3.hsl().
 */
function lighten(color) {
  const shade = d3.hsl(color);
  const delta = (1 - shade.l) * 0.6;
  shade.l = shade.l + delta;
  shade.s = shade.s - delta;
  return shade;
}
/**
 * Transposes a rectangular matrix given as an array of rows.
 *
 * The column count is taken from the first row. An empty matrix has no first row, so it is
 * returned as an empty array instead of raising a TypeError on `mat[0]`.
 *
 * @param mat Array of equally long row arrays.
 * @returns New array of columns; result[c][r] === mat[r][c].
 */
function transpose(mat) {
  if (mat.length === 0) {
    return [];
  }
  return mat[0].map((_, colIndex) => mat.map((row) => row[colIndex]));
}
}
// );
}