Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 2,904 Bytes
63858e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import * as x_ from '../etc/_Tools'
import * as _ from 'lodash'
import * as tp from '../etc/types'
import * as R from 'ramda'
/**
 * Placeholder token list used when no real tokens are available: a single
 * `[SEP]` token with empty top-k word/probability lists.
 *
 * This is a plain (mutable) module-level array, so consumers should copy it
 * rather than modify it in place.
 */
const emptyFullResponse: tp.FullSingleTokenInfo[] = [
  {
    text: '[SEP]',
    topk_words: [],
    topk_probs: [],
  },
]
/**
 * The original tokens, and the indexes of those tokens that need to be masked.
 */
export class TokenDisplay {
  /** Token information, one entry per position. */
  tokenData: tp.FullSingleTokenInfo[]
  /** Indexes into `tokenData` that are masked; `mask` inserts them smallest to largest. */
  maskInds: number[]

  /**
   * @param tokens   Token info for each position. When omitted, a fresh deep copy of
   *                 the module's single-`[SEP]` placeholder is used, so instances never
   *                 share (and can never corrupt) that module-level default.
   * @param maskInds Initial masked indexes. Stored by reference: `mask` and `unmask`
   *                 modify this array in place.
   */
  constructor(tokens: tp.FullSingleTokenInfo[] = _.cloneDeep(emptyFullResponse), maskInds: number[] = []) {
    this.tokenData = tokens;
    this.maskInds = maskInds;
  }

  /**
   * Push idx to the mask idx list in order from smallest to largest.
   * An index that is already masked is only logged; the list is left unchanged.
   */
  mask(val: number) {
    if (_.includes(this.maskInds, val)) {
      console.log(`${val} already in maskInds!`);
      console.log(this.maskInds);
      return;
    }
    // The return value of orderedInsert_ is not used, so it is relied on to
    // insert into this.maskInds in place.
    x_.orderedInsert_(this.maskInds, val)
  }

  /** Mask `val` if it is not masked yet, otherwise unmask it. */
  toggle(val: number) {
    if (_.includes(this.maskInds, val)) {
      console.log(`Unmasking ${val}`);
      this.unmask(val)
    } else {
      console.log(`Masking ${val}`);
      this.mask(val)
    }
  }

  /** Remove every occurrence of `val` from `maskInds` (in place, via `_.pull`). */
  unmask(val: number) {
    _.pull(this.maskInds, val);
  }

  /**
   * Clear the mask. This assigns a new empty array instead of emptying the old
   * one, so an array passed to the constructor is no longer linked afterwards.
   */
  resetMask() {
    this.maskInds = [];
  }

  /** Number of tokens held. */
  length() {
    return this.tokenData.length;
  }

  /**
   * Return a new TokenDisplay with `other`'s tokens appended after this one's.
   * `other`'s masked indexes are shifted by this display's length so they keep
   * pointing at the same tokens. Token objects are shared with both inputs, not copied.
   */
  concat(other: TokenDisplay) {
    const offset = this.length();
    const newTokens = _.concat(this.tokenData, other.tokenData);
    const newMask = _.concat(this.maskInds, other.maskInds.map(x => x + offset));
    return new TokenDisplay(newTokens, newMask);
  }
}
/**
 * Holds the `TokenDisplay` built from an attention response.
 * Only one token sequence, `a`, is tracked; it is taken from `r.aa.left`.
 */
export class TokenWrapper {
  a: TokenDisplay

  constructor(r: tp.AttentionResponse) {
    this.updateFromResponse(r);
  }

  /** Rebuild `a` from `r.aa.left`, with nothing masked. */
  updateFromResponse(r: tp.AttentionResponse) {
    const tokensA = r.aa.left;
    this.updateFromComponents(tokensA, [])
  }

  /** Replace `a` with a new TokenDisplay over the given tokens and masked indexes. */
  updateFromComponents(a: tp.FullSingleTokenInfo[], maskA: number[]) {
    this.a = new TokenDisplay(a, maskA)
  }

  /**
   * Copy the `contexts`, `embeddings`, `topk_probs` and `topk_words` fields of
   * `r.aa.left` onto the existing token objects of `a`, position by position.
   * Other fields (e.g. `text`) and the mask are left untouched. Keys missing
   * from a new token are skipped (`R.pick` omits them). Pairing uses `R.zip`,
   * so only min(old length, new length) positions are updated.
   */
  updateTokens(r: tp.AttentionResponse) {
    const desiredKeys = ['contexts', 'embeddings', 'topk_probs', 'topk_words']
    const newTokens = r.aa.left.map(v => R.pick(desiredKeys, v))
    R.zip(this.a.tokenData, newTokens).forEach(([current, fresh]) => {
      // Mutates the token objects in place, so other holders of them see the update.
      Object.assign(current, fresh)
    })
  }

  /**
   * Mask the appropriate sentence at the index indicated.
   * NOTE(review): only `a` is defined on this class; an `sID` naming any other
   * property makes `this[sID]` undefined and throws — confirm against tp.TokenOptions.
   */
  mask(sID: tp.TokenOptions, idx: number) {
    this[sID].mask(idx)
  }
}
/**
 * Map a side of the attention view to the sentence letter for that side.
 *
 * @param side  `"left"` selects the first character of `atype`; any other value
 *              selects the second.
 * @param atype Attention type code. `"all"` is returned unchanged; otherwise the
 *              code is indexed by character.
 *              NOTE(review): this presumes a two-character code; the type system
 *              does not enforce that, and a shorter code yields `undefined`.
 * @returns `"all"`, or the character of `atype` at index 0 (left) or 1 (otherwise).
 */
export function sideToLetter(side: tp.SideOptions, atype: tp.SentenceOptions) {
  if (atype === "all") {
    return "all"
  }
  return side === "left" ? atype[0] : atype[1]
}
|