File size: 3,686 Bytes
63858e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import * as tp from '../etc/types'
import * as d3 from 'd3'
import 'd3-array'
import * as R from 'ramda'
import {SpacyInfo} from '../etc/SpacyInfo'
import {initZero} from '../etc/xramda'

/**
 * Lower-case `val` when it is a string (primitive or `String` wrapper object);
 * any other value is handed back unchanged, so callers never need to check
 * the type before normalizing.
 */
function makeStringLower<T>(val: T): T | string {
  if (typeof val !== 'string' && !(val instanceof String)) {
    return val
  }
  return String(val).toLowerCase()
}

/**
 * Index of the largest element of `array`.
 *
 * Ties resolve to the earliest index: only a strictly greater value replaces
 * the current best, which also means a NaN never displaces a real maximum
 * (and a leading NaN is never displaced).
 *
 * @param array Values to scan.
 * @returns Index of the first maximum, or -1 when `array` is empty
 *   (previously this threw a TypeError from `reduce`).
 */
function argMax(array: number[]) {
  if (array.length === 0) return -1;

  let best = 0;
  for (let i = 1; i < array.length; i++) {
    if (array[i] > array[best]) best = i;
  }
  return best;
}


/**
 * Wraps a list of FAISS search results and aggregates token metadata
 * (pos / dep / is_ent, attention offsets) into count objects for histograms.
 *
 * When `showNext` is on, the attention and index fields are read from
 * `matched_att_plus_1` / `next_index` instead of `matched_att` / `index`.
 */
export class FaissSearchResultWrapper {
    data: tp.FaissSearchResults[]

    options = {
        showNext: false
    }

    /**
     * @param data Search results to aggregate.
     * @param showNext Initial value of the `showNext` option (see `showNext()`).
     */
    constructor(data: tp.FaissSearchResults[], showNext=false) {
        this.data = data
        this.options.showNext = showNext
    }

    /** Name of the attention field to read: `matched_att_plus_1` when showNext, else `matched_att`. */
    get matchAtt() {
        return this.showNext() ? "matched_att_plus_1" : "matched_att"
    }

    /** Name of the token-index field to read: `next_index` when showNext, else `index`. */
    get matchIdx() {
        return this.showNext() ? "next_index" : "index"
    }

    /**
     * Count how often each `out.offset_to_max` value occurs across all results,
     * reading the attention field named by `this.matchAtt`.
     *
     * @returns `{ offset: counts }`, where `counts` is seeded by `initZero`
     *   from the observed offsets (presumably a zero-filled map keyed by
     *   offset value — confirm in xramda) and incremented once per result.
     */
    countPosInfo() {
        const attOffsets = this.data.map(d => +d[this.matchAtt].out.offset_to_max)

        const ctObj = {
            offset: initZero(attOffsets)
        }

        attOffsets.forEach(v => {
            Object.keys(ctObj).forEach((k) => {
                // Start absent keys at 0 so a missing key can never yield NaN
                ctObj[k][v] = (ctObj[k][v] || 0) + 1
            })
        })

        return ctObj
    }

    /**
     * Count pos / dep / is_ent of the token that receives the highest
     * attention (`argMax` of `matched_att.out.att`) in each result, merged
     * with the `offset` counts from `countPosInfo()`.
     *
     * @param indexOffset Accepted for signature parity with `countMatchedKeys`;
     *   currently unused.
     * @returns Object with `pos`, `dep`, `is_ent` and `offset` count maps.
     */
    countMaxAttKeys(indexOffset=0) {
        // The keys in the below object dictate what we count
        const countObj = {
            pos: initZero(SpacyInfo.TotalMetaOptions.pos),
            dep: initZero(SpacyInfo.TotalMetaOptions.dep),
            is_ent: initZero(SpacyInfo.TotalMetaOptions.is_ent),
        }

        // Confusing: Show MATCHED WORD attentions, but NEXT WORD distribution
        // NOTE(review): this always reads `matched_att`, while countPosInfo reads
        // `this.matchAtt` (which follows showNext) — confirm the asymmetry is intended.
        const getMaxToken = (d: tp.FaissSearchResults) => d.tokens[argMax(d.matched_att.out.att)]

        this.data.forEach(d => {
            const maxMatch = getMaxToken(d)

            Object.keys(countObj).forEach(k => {
                const val = makeStringLower(String(maxMatch[k]))
                // A category outside the preset options would otherwise become
                // `undefined + 1 === NaN`, which the zero filter lets through.
                countObj[k][val] = (countObj[k][val] || 0) + 1
            })
        })

        return Object.assign(countObj, this.countPosInfo())
    }

    /**
     * Count pos / dep / is_ent of the token at `d[this.matchIdx] + indexOffset`
     * in each result.
     *
     * @param indexOffset Shift applied to the matched token index.
     * @returns Object with `pos`, `dep` and `is_ent` count maps.
     */
    countMatchedKeys(indexOffset=0) {
        // The keys in the below object dictate what we count
        const countObj = {
            pos: initZero(SpacyInfo.TotalMetaOptions.pos),
            dep: initZero(SpacyInfo.TotalMetaOptions.dep),
            is_ent: initZero(SpacyInfo.TotalMetaOptions.is_ent),
        }

        this.data.forEach(d => {
            // Confusing: Show MATCHED WORD attentions, but NEXT WORD distribution
            // NOTE(review): an index past the end of `d.tokens` leaves `match`
            // undefined and the property access below throws.
            const match = d.tokens[d[this.matchIdx] + indexOffset]

            Object.keys(countObj).forEach(k => {
                const val = makeStringLower(String(match[k]))
                // Unknown categories start at 0 instead of becoming NaN
                countObj[k][val] = (countObj[k][val] || 0) + 1
            })
        })

        return countObj
    }

    /**
     * `countMatchedKeys(indexOffset)` with zero-count categories removed.
     *
     * @param indexOffset Forwarded to `countMatchedKeys`.
     */
    getMatchedHistogram(indexOffset=0) {
        const totalHist = this.countMatchedKeys(indexOffset)
        const isNonZero = (val: number) => val !== 0
        return R.map(R.pickBy(isNonZero), totalHist)
    }

    /**
     * `countMaxAttKeys()` (including the `offset` counts) with zero-count
     * categories removed.
     */
    getMaxAttHistogram() {
        const newHist = this.countMaxAttKeys()
        const isNonZero = (val: number) => val !== 0
        return R.map(R.pickBy(isNonZero), newHist)
    }

    /**
     * Getter/setter for the `showNext` option.
     *
     * Called with no argument (or null/undefined) it returns the current value;
     * called with a boolean it stores it and returns `this` for chaining.
     */
    showNext(): boolean
    showNext(v:boolean): this
    showNext(v?) {
        if (v == null) return this.options.showNext

        this.options.showNext = v
        return this
    }
}