import * as x_ from '../etc/_Tools'
import * as _ from 'lodash'
import * as tp from '../etc/types'
import * as R from 'ramda'

/**
 * Placeholder token list used when a TokenDisplay is constructed without
 * tokens: a single '[SEP]' token with no top-k predictions.
 */
const emptyFullResponse: tp.FullSingleTokenInfo[] = [{
    text: '[SEP]',
    topk_words: [],
    topk_probs: []
}]

/**
 * The original tokens, and the indexes that need to be masked.
 *
 * `maskInds` is mutated in place by `mask` and `unmask`. `mask` inserts via
 * `x_.orderedInsert_`, which is intended to keep the list ordered from
 * smallest to largest.
 */
export class TokenDisplay {
    tokenData: tp.FullSingleTokenInfo[]
    maskInds: number[]

    /**
     * @param tokens - token info to display. Defaults to a fresh deep copy of
     *   the '[SEP]' placeholder, so default-constructed instances never share
     *   (and cannot mutate) the module-level constant.
     * @param maskInds - indexes to mask. Stored by reference, not copied.
     */
    constructor(tokens: tp.FullSingleTokenInfo[] = _.cloneDeep(emptyFullResponse), maskInds: number[] = []) {
        this.tokenData = tokens;
        this.maskInds = maskInds;
    }

    /**
     * Push idx to the mask idx list in order from smallest to largest.
     * If `val` is already masked, the list is left unchanged and a message is logged.
     */
    mask(val: number) {
        if (!this.maskInds.includes(val)) {
            x_.orderedInsert_(this.maskInds, val)
        } else {
            console.log(`${val} already in maskInds!`);
            console.log(this.maskInds);
        }
    }

    /**
     * Mask `val` if it is not currently masked, otherwise unmask it.
     */
    toggle(val: number) {
        if (!this.maskInds.includes(val)) {
            console.log(`Masking ${val}`);
            this.mask(val)
        } else {
            console.log(`Unmasking ${val}`);
            this.unmask(val)
        }
    }

    /**
     * Remove every occurrence of `val` from the mask list (in place, via `_.pull`).
     */
    unmask(val: number) {
        _.pull(this.maskInds, val);
    }

    /**
     * Clear all masks. This replaces the array rather than emptying it, so any
     * array previously passed in as `maskInds` is left untouched.
     */
    resetMask() {
        this.maskInds = [];
    }

    /** Number of tokens held. */
    length() {
        return this.tokenData.length;
    }

    /**
     * Return a new TokenDisplay with `other`'s tokens appended after this one's.
     * `other`'s mask indexes are shifted by `this.length()` so they still refer
     * to the same tokens in the combined list. Neither input is modified.
     */
    concat(other: TokenDisplay) {
        const offset = this.length();
        const newTokens = _.concat(this.tokenData, other.tokenData);
        const newMask = _.concat(this.maskInds, other.maskInds.map(x => x + offset));
        return new TokenDisplay(newTokens, newMask);
    }
}

/**
 * Holds the TokenDisplay built from the `aa.left` tokens of an attention response.
 */
export class TokenWrapper {
    // Assigned in the constructor through updateFromResponse.
    a!: TokenDisplay

    constructor(r: tp.AttentionResponse) {
        this.updateFromResponse(r);
    }

    /** Rebuild `a` from the response's `aa.left` tokens, with no masks. */
    updateFromResponse(r: tp.AttentionResponse) {
        const tokensA = r.aa.left;
        this.updateFromComponents(tokensA, [])
    }

    /** Replace `a` with a new TokenDisplay over the given tokens and mask indexes. */
    updateFromComponents(a: tp.FullSingleTokenInfo[], maskA: number[]) {
        this.a = new TokenDisplay(a, maskA)
    }

    /**
     * Refresh the model outputs of the existing tokens in place, keeping their
     * text and the current mask. Only 'contexts', 'embeddings', 'topk_probs'
     * and 'topk_words' are copied from the new response. Tokens are paired by
     * position (`R.zip` stops at the shorter of the two lists).
     */
    updateTokens(r: tp.AttentionResponse) {
        const desiredKeys = ['contexts', 'embeddings', 'topk_probs', 'topk_words']
        const newTokens = r.aa.left.map(v => R.pick(desiredKeys, v))
        R.zip(this.a.tokenData, newTokens).forEach(([oldTok, newTok]) => {
            Object.assign(oldTok, newTok)
        })
    }

    /**
     * Mask the appropriate sentence at the index indicated.
     *
     * NOTE(review): only `a` exists on this class. If `tp.TokenOptions`
     * allows other keys (e.g. "b"), `this[sID]` is undefined for them and this
     * call throws. Confirm against the type definition.
     */
    mask(sID: tp.TokenOptions, idx: number) {
        this[sID].mask(idx)
    }
}

/**
 * Pick the sentence letter that corresponds to a display side.
 *
 * @param side - "left" selects `atype[0]`. Any other side selects `atype[1]`.
 * @param atype - sentence option string. "all" is returned unchanged.
 *   Presumably a two-character code otherwise. TODO confirm against tp.SentenceOptions.
 * @returns "all", or a single character of `atype`.
 */
export function sideToLetter(side: tp.SideOptions, atype: tp.SentenceOptions) {
    if (atype === "all") {
        return "all"
    }
    // No validation that atype actually has two characters.
    const out = side === "left" ? atype[0] : atype[1]
    return out
}