File size: 2,904 Bytes
63858e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import * as x_ from '../etc/_Tools'
import * as _ from 'lodash'
import * as tp from '../etc/types'
import * as R from 'ramda'

/**
 * Placeholder token list used as the default `tokens` argument of
 * `TokenDisplay`: a single '[SEP]' token with empty top-k word/probability
 * lists.
 */
const emptyFullResponse: tp.FullSingleTokenInfo[] = [{
    text: '[SEP]',
    topk_words: [],
    topk_probs: []
}]

/**
 * The original tokens, and the indexes that need to be masked.
 *
 * `maskInds` is kept ordered from smallest to largest by `mask`.
 */
export class TokenDisplay  {
    tokenData:tp.FullSingleTokenInfo[]
    maskInds:number[]

    /**
     * @param tokens   Token records to display. When omitted, a fresh deep copy
     *                 of `emptyFullResponse` is used. Sharing the module-level
     *                 array would let in-place edits of token objects leak into
     *                 every other default-constructed instance.
     * @param maskInds Initially masked indices. Stored by reference and mutated
     *                 in place by `mask`/`unmask`.
     *                 NOTE(review): assumed to already be sorted ascending —
     *                 confirm with callers.
     */
    constructor(tokens: tp.FullSingleTokenInfo[] = _.cloneDeep(emptyFullResponse), maskInds: number[] = []){
        this.tokenData = tokens;
        this.maskInds = maskInds;
    }

    /** True when `val` is currently in `maskInds`. */
    isMasked(val: number): boolean {
        return _.includes(this.maskInds, val);
    }

    /**
     * Push idx to the mask idx list in order from smallest to largest
     * (via `x_.orderedInsert_`). If `val` is already masked, the list is left
     * unchanged and a diagnostic is logged instead.
     */
    mask(val: number) {
        if (!this.isMasked(val)) {
            x_.orderedInsert_(this.maskInds, val)
            return
        }
        console.log(`${val} already in maskInds!`);
        console.log(this.maskInds);
    }

    /** Mask `val` if it is not currently masked, otherwise unmask it. */
    toggle(val: number) {
        if (this.isMasked(val)) {
            console.log(`Unmasking ${val}`);
            this.unmask(val)
        }
        else {
            console.log(`Masking ${val}`);
            this.mask(val)
        }
    }

    /** Remove every occurrence of `val` from `maskInds`, in place (`_.pull`). */
    unmask(val: number) {
        _.pull(this.maskInds, val);
    }

    /**
     * Clear all masks. This assigns a new empty array rather than emptying the
     * existing one, so an array passed to the constructor is no longer
     * referenced afterwards.
     */
    resetMask() {
        this.maskInds = [];
    }

    /** Number of tokens held. */
    length() {
        return this.tokenData.length;
    }

    /**
     * Return a new display containing this display's tokens followed by
     * `other`'s. `other`'s mask indices are shifted by this display's length so
     * they still point at the same tokens. The token objects themselves are
     * shared, not copied.
     */
    concat(other: TokenDisplay) {
        const offset = this.length();
        const newTokens = _.concat(this.tokenData, other.tokenData);
        const newMask = _.concat(this.maskInds, other.maskInds.map(x => x + offset));
        return new TokenDisplay(newTokens, newMask);
    }
}

/**
 * Holds the `TokenDisplay` built from `r.aa.left` of an attention response,
 * and lets it be rebuilt or refreshed from later responses.
 */
export class TokenWrapper {
    // Always assigned in the constructor through updateFromResponse(). The
    // compiler cannot see through that call, hence the definite-assignment `!`.
    a!: TokenDisplay

    constructor(r:tp.AttentionResponse){
        this.updateFromResponse(r);
    }

    /** Rebuild the display from `r.aa.left`, with no masked indices. */
    updateFromResponse(r:tp.AttentionResponse) {
        const tokensA = r.aa.left;
        this.updateFromComponents(tokensA, [])
    }

    /** Replace the display with a new one built from `a` and mask list `maskA`. */
    updateFromComponents(a:tp.FullSingleTokenInfo[], maskA:number[]){
        this.a = new TokenDisplay(a, maskA)
    }

    /**
     * Refresh the existing token objects in place from `r.aa.left`.
     *
     * Only the keys 'contexts', 'embeddings', 'topk_probs' and 'topk_words'
     * are copied, and only those present on each new token (`R.pick` skips
     * missing keys). All other fields and the current mask are kept.
     * Tokens are paired by position with `R.zip`, which stops at the shorter
     * list, so any surplus tokens on either side are left untouched.
     */
    updateTokens(r: tp.AttentionResponse) {
        const desiredKeys = ['contexts', 'embeddings', 'topk_probs', 'topk_words']
        const newTokens = r.aa.left.map(v => R.pick(desiredKeys, v))

        R.zip(this.a.tokenData, newTokens).forEach(([current, fresh]) => {
            Object.assign(current, fresh)
        })
    }

    /**
     * Mask the appropriate sentence at the index indicated.
     *
     * NOTE(review): this class only defines the `a` display. Any `sID` other
     * than "a" makes `this[sID]` undefined, and this call then throws.
     */
    mask(sID:tp.TokenOptions, idx:number){
        this[sID].mask(idx)
    }
}

/**
 * Select the character of the attention-type code `atype` that corresponds
 * to `side`: `atype[0]` for "left", `atype[1]` for any other side.
 * The special type "all" is returned unchanged, whatever the side.
 *
 * NOTE(review): `atype` is presumably a two-character code such as "ab" —
 * confirm against `tp.SentenceOptions`. The result is typed as plain
 * `string`; narrowing it further would require the definitions in `tp`.
 */
export function sideToLetter(side:tp.SideOptions, atype:tp.SentenceOptions){
    if (atype === "all") {
        return "all"
    }
    return side === "left" ? atype[0] : atype[1]
}