import * as _ from 'lodash'
import * as x_ from '../etc/_Tools'
import * as tp from '../etc/types'
import * as tf from '@tensorflow/tfjs'

/**
 * Notes:
 *
 * - TODO: also encapsulate whether the CLS/SEP tokens are present or absent.
 * - TODO: when the layer format changes from a list, drop the index into conf.layer.
 */

// Token strings treated as special positions whose attention rows/columns can be zeroed.
// NOTE(review): the two "" entries duplicate each other and make any token with empty text
// count as special. They may be placeholders for other special tokens (e.g. "<s>" / "</s>")
// that were lost in an earlier edit — confirm before changing.
const bpeTokens = ["[CLS]", "[SEP]", "", "", "<|endoftext|>"]

/**
 * Indexes (as computed by `x_.findAllIndexes`) of the tokens in `x` whose `text` is one of
 * `bpeTokens`.
 */
const findBadIndexes = (x: tp.FullSingleTokenInfo[]) =>
    x_.findAllIndexes(x.map(t => t.text), (a) => _.includes(bpeTokens, a))

/**
 * Build an AttentionWrapper from the 'aa' entry of a backend attention response.
 *
 * @param r - backend response; its `aa` entry supplies `left`/`right` tokens and `att`.
 * @param isZeroed - whether the wrapper should initially expose the zeroed attention.
 */
export function makeFromMetaResponse(r: tp.AttentionResponse, isZeroed: boolean) {
    const key = 'aa' // Change this if backend response changes to be simpler
    const currPair = r[key]
    const leftZero = findBadIndexes(currPair.left)
    const rightZero = findBadIndexes(currPair.right)
    return new AttentionWrapper(currPair.att, [leftZero, rightZero], isZeroed)
}

/**
 * Holds one rank-3 attention array together with two tfjs tensors: the raw attention and a
 * copy whose special-token rows/columns are zeroed. `isZeroed` selects which one is exposed.
 *
 * Axis 0 is treated as the head axis (see `_byHead`). Rows (axis 1) are zeroed at
 * `badToks[0]` and columns (axis 2) at `badToks[1]`.
 *
 * The wrapper owns its tensors: they are disposed when `init` runs again, so callers should
 * not hold on to `attTensor` across an update.
 */
export class AttentionWrapper {
    protected _att: number[][][]
    protected _attTensor: tf.Tensor3D
    protected _zeroedAttTensor: tf.Tensor3D
    badToks: [number[], number[]] // [row, column] indexes of special tokens (see bpeTokens)
    isZeroed: boolean
    nLayers = 12;
    nHeads = 12;

    constructor(att: number[][][], badToks: [number[], number[]] = [[], []], isZeroed = true) {
        this.init(att, badToks, isZeroed)
    }

    /**
     * (Re)initialize from a raw attention array.
     *
     * @param att - nested array convertible by `tf.tensor3d`.
     * @param badToks - [row indexes, column indexes] to zero in the zeroed tensor.
     * @param isZeroed - whether `attTensor` should return the zeroed tensor.
     */
    init(att: number[][][], badToks: [number[], number[]] = [[], []], isZeroed: boolean) {
        // tfjs does not garbage-collect tensor memory; release tensors from a previous init
        // (e.g. via updateFromNormal) before replacing them.
        if (this._attTensor != null) this._attTensor.dispose()
        if (this._zeroedAttTensor != null) this._zeroedAttTensor.dispose()

        this.isZeroed = isZeroed
        this._att = att
        this._attTensor = tf.tensor3d(att)
        // zeroRowCol returns a new tensor and never modifies its input, so it can safely be
        // derived from _attTensor.
        this._zeroedAttTensor = zeroRowCol(this._attTensor, badToks[0], badToks[1])
        this.badToks = badToks
    }

    /**
     * Replace the wrapped attention with the 'aa' entry of a new backend response.
     */
    updateFromNormal(r: tp.AttentionResponse, isZeroed: boolean) {
        const key = 'aa' // Change this if backend response changes to be simpler
        const currPair = r[key]
        const leftZero = findBadIndexes(currPair.left)
        const rightZero = findBadIndexes(currPair.right)
        this.init(currPair.att, [leftZero, rightZero], isZeroed)
    }

    /** The zeroed tensor when `isZeroed` is set, otherwise the raw one. Owned by the wrapper. */
    get attTensor() {
        return this.isZeroed ? this._zeroedAttTensor : this._attTensor
    }

    /** Nested-array copy of `attTensor`. */
    get att() {
        return this.attTensor.arraySync()
    }

    /** Getter/setter for `isZeroed`; the setter form returns `this` for chaining. */
    zeroed(): boolean
    zeroed(val: boolean): this
    zeroed(val?: boolean) {
        if (val == null) return this.isZeroed
        this.isZeroed = val
        return this
    }

    toggleZeroing() {
        this.zeroed(!this.zeroed())
    }

    /**
     * Sum of the attention matrices of the given heads (gather on axis 0, then sum over it).
     * An empty list yields a zero matrix shaped like a single head.
     */
    protected _byHeads(heads: number[]): tf.Tensor2D {
        if (heads.length == 0) {
            return tf.zerosLike(this._byHead(0))
        }
        return (this.attTensor.gather(heads, 0).sum(0))
    }

    /** The attention matrix of a single head. */
    protected _byHead(head: number): tf.Tensor2D {
        return (this.attTensor.gather([head], 0).squeeze([0]))
    }

    /** Nested-array form of `_byHeads`; all temporary tensors are released. */
    byHeads(heads: number[]): number[][] {
        const out = tf.tidy(() => this._byHeads(heads))
        try {
            return out.arraySync()
        } finally {
            out.dispose()
        }
    }

    /** Nested-array form of `_byHead`; all temporary tensors are released. */
    byHead(head: number): number[][] {
        const out = tf.tidy(() => this._byHead(head))
        try {
            return out.arraySync()
        } finally {
            out.dispose()
        }
    }
}

/**
 * Return a new tensor equal to `tens` with rows `rows` (axis 1) and columns `cols` (axis 2)
 * set to 0 for every slice along axis 0. Indexes outside the tensor are ignored.
 *
 * Implemented as a broadcast multiply by a 0/1 mask, so the input tensor is left untouched
 * and the result does not depend on whether the backend aliases `dataSync()` buffers.
 * Assumes finite values (NaN/Infinity times 0 is NaN, not 0).
 */
function zeroRowCol(tens: tf.Tensor3D, rows: number[], cols: number[]): tf.Tensor3D {
    const [, nRows, nCols] = tens.shape
    const colBad = _.range(nCols).map(j => _.includes(cols, j))
    const keep: number[] = []
    for (let i = 0; i < nRows; i++) {
        const rowBad = _.includes(rows, i)
        for (let j = 0; j < nCols; j++) {
            keep.push(rowBad || colBad[j] ? 0 : 1)
        }
    }
    // tidy disposes the temporary mask tensor and keeps only the returned product.
    return tf.tidy(() => tens.mul(tf.tensor2d(keep, [nRows, nCols])))
}