import * as tp from "./etc/types"
import * as x_ from "./etc/_Tools"
import * as _ from "lodash"
import * as R from 'ramda'
import { URLHandler } from "./etc/URLHandler";

/**
 * Values treated as "false" when interpreting flags (e.g. URL parameters or
 * setter arguments). Built once at module load instead of on every call.
 */
const FALSEY_VALUES = new Set<unknown>(['false', 0, "no", false, null, ""])

/**
 * True if `val` is exactly one of FALSEY_VALUES (Set membership, so e.g. the
 * string '0' and `undefined` are NOT falsey).
 */
const falsey = (val: unknown): boolean => FALSEY_VALUES.has(val)

/** Logical negation of `falsey`; note that `undefined` therefore counts as truthy. */
const truthy = (val: unknown): boolean => !falsey(val)

const toNumber = x => +x;

/**
 * Return `val`, or `fallback` when `val` is null, undefined or the empty string.
 *
 * Unlike `val || fallback`, legitimate falsy values such as 0 are kept, so a
 * token index, layer or threshold of 0 read from the URL survives a reload.
 */
function orDefault<T>(val: T | null | undefined, fallback: T): T {
    if (val == null || (val as unknown) === "") {
        return fallback
    }
    return val
}

type InspectorOptions = "context" | "embeddings" | null

// Must be optional params for initializations
interface URLParameters {
    sentence?: string
    model?: string
    modelKind?: string
    layer?: number
    heads?: number[]
    threshold?: number
    tokenInd?: number | 'null'
    tokenSide?: tp.SideOptions
    maskInds?: number[]
    hideClsSep?: boolean
}

/**
 * UI state that is mirrored into the page URL.
 *
 * Getter/setter pairs follow a single pattern: called with no argument they
 * return the current value; called with a value they store it, sync the URL via
 * `toURL()`, and return `this` for chaining.
 */
export class UIConfig {

    private _conf: URLParameters = {}
    private _headSet: Set<number>;
    attType: "aa"
    _nHeads: number | null;
    _nLayers: number | null;
    private _token: tp.TokenEvent;

    /**
     * Initialise defaults (12 heads, unknown layer count), load state from the
     * URL, then write the fully-resolved configuration back to the URL.
     */
    constructor() {
        this._nHeads = 12;
        this._nLayers = null;
        this.attType = 'aa'
        this.fromURL()
        this.toURL(false)
    }

    /**
     * Write the current configuration to the URL through URLHandler.updateUrl.
     * `updateHistory` is forwarded unchanged (presumably controls whether a new
     * history entry is created — confirm against URLHandler).
     */
    toURL(updateHistory = false) {
        URLHandler.updateUrl(this._conf, updateHistory)
    }

    /**
     * Rebuild the configuration from URLHandler.parameters, filling defaults for
     * anything missing. Does not itself write to the URL.
     */
    fromURL() {
        const params = URLHandler.parameters
        const rawHideClsSep = params['hideClsSep']

        this._conf = {
            model: params['model'] || 'bert-base-cased',
            modelKind: params['modelKind'] || tp.ModelKind.Bidirectional,
            sentence: params['sentence'] || "The girl ran to a local pub to escape the din of her city.",
            // Null/empty-aware defaults so that 0 is a valid layer / threshold / token index.
            layer: orDefault(params['layer'], 1),
            heads: this._initHeads(params['heads']),
            threshold: orDefault(params['threshold'], 0.7),
            tokenInd: orDefault(params['tokenInd'], null),
            tokenSide: params['tokenSide'] || null,
            maskInds: orDefault(params['maskInds'], [9]),
            // Default to true only when the parameter is absent; previously `|| true`
            // made this flag impossible to turn off from the URL.
            hideClsSep: (rawHideClsSep == null || rawHideClsSep === "") ? true : truthy(rawHideClsSep),
        }

        this._token = { side: this._conf.tokenSide, ind: this._conf.tokenInd }
    }

    /**
     * Initialise the selected-head set from the URL value: all heads
     * (0 .. nHeads-1) when `v` is missing or empty, otherwise the given heads.
     * Sets `_headSet` directly rather than via `headSet()`, because that setter
     * calls `toURL()` while `_conf` is still the previous/partial object.
     *
     * @returns the selected heads as a sorted array.
     */
    private _initHeads(v: number[] | null): number[] {
        const initial = (v == null || v.length < 1) ? _.range(0, this._nHeads) : v
        this._headSet = new Set<number>(initial)
        return x_.set2SortedArray(this._headSet)
    }

    /** Number of attention heads (getter), or set it and return `this`. */
    nHeads(): number
    nHeads(val: number): this
    nHeads(val?) {
        if (val == null) return this._nHeads
        this._nHeads = val
        return this
    }

    /** Number of layers (getter; null until set), or set it and return `this`. */
    nLayers(): number
    nLayers(val: number): this
    nLayers(val?) {
        if (val == null) return this._nLayers
        this._nLayers = val
        return this
    }

    /** Select every head if none are selected; otherwise clear the selection. */
    toggleSelectAllHeads() {
        if (this.heads().length === 0) {
            this.selectAllHeads()
        }
        else {
            this.selectNoHeads()
        }
    }

    /** Select heads 0 .. nHeads-1 and sync the URL. */
    selectAllHeads() {
        this.headSet(new Set<number>(_.range(0, this._nHeads)))
    }

    /** Clear the head selection and sync the URL. */
    selectNoHeads() {
        this.headSet(new Set<number>([]))
    }

    /**
     * Add `head` to the selection if absent, remove it if present.
     *
     * @returns tp.Toggled.ADDED or tp.Toggled.REMOVED describing what happened.
     */
    toggleHead(head: number): tp.Toggled {
        let out;
        if (this.headSet().has(head)) {
            this.headSet().delete(head);
            out = tp.Toggled.REMOVED
        }
        else {
            this.headSet().add(head);
            out = tp.Toggled.ADDED
        }

        // The set was mutated in place above; pass it back through the setter so
        // `_conf.heads` and the URL are refreshed.
        this.headSet(this.headSet());
        return out
    }

    /**
     * Clear the token if `e` has the same `ind` and `side` as the current token,
     * otherwise make `e` the current token.
     */
    toggleToken(e: tp.TokenEvent): this {
        const picker = R.pick(['ind', 'side'])
        const compareEvent = picker(e)
        const compareToken = picker(this.token())

        if (R.equals(compareToken, compareEvent)) {
            this.rmToken();
        }
        else {
            this.token(e);
        }

        return this;
    }

    /** Current token (getter), or set the token, mirror it into `_conf`, and sync the URL. */
    token(): tp.TokenEvent;
    token(val: tp.TokenEvent): this;
    token(val?: tp.TokenEvent) {
        if (val == null)
            return this._token

        this._token = val;
        this._conf.tokenInd = val.ind;
        this._conf.tokenSide = val.side;
        this.toURL();
        return this
    }

    /**
     * False when both tokenInd and tokenSide are null/undefined, or when
     * tokenInd is the literal string "null"; true otherwise.
     */
    hasToken() {
        const conf = this._conf
        const actuallyNull = ((conf.tokenInd == null) && (conf.tokenSide == null))
        const strNull = (conf.tokenInd == "null")
        return (!actuallyNull) && (!strNull)
    }

    /** Reset the token to `{ ind: null, side: null }` and sync the URL. */
    rmToken() {
        this.token({ ind: null, side: null });
        return this
    }

    /** Input sentence (getter), or set it and sync the URL with `updateHistory = true`. */
    sentence(): string;
    sentence(val: string): this;
    sentence(val?) {
        if (val == null) return this._conf.sentence
        this._conf.sentence = val
        this.toURL(true)
        return this
    }

    /** Attention threshold (getter), or set it and sync the URL. */
    threshold(): number;
    threshold(val: number): this;
    threshold(val?) {
        if (val == null) return this._conf.threshold;

        this._conf.threshold = val;
        this.toURL();
        return this;
    }

    /** Selected heads as stored in the configuration (sorted by headSet()). */
    heads(): number[] {
        return this._conf.heads
    }

    /** Selected layer (getter), or set it and sync the URL. */
    layer(): number
    layer(val: number): this
    layer(val?) {
        if (val == null) return this._conf.layer

        this._conf.layer = val;
        this.toURL();
        return this
    }

    /**
     * The live set of selected heads (getter), or replace it, store its sorted
     * array form in `_conf.heads`, and sync the URL.
     */
    headSet(): Set<number>;
    headSet(val: Set<number>): this;
    headSet(val?) {
        if (val == null) {
            return this._headSet;
        }

        this._headSet = val;
        this._conf.heads = x_.set2SortedArray(this._headSet)
        this.toURL();
        return this
    }

    /** Masked token indices (getter), or set them and sync the URL. */
    maskInds(): number[];
    maskInds(val: number[]): this;
    maskInds(val?) {
        if (val == null) return this._conf.maskInds;

        this._conf.maskInds = val;
        this.toURL();
        return this;
    }

    /** Whether CLS/SEP are hidden (getter), or set it (coerced via `truthy`) and sync the URL. */
    hideClsSep(): boolean;
    hideClsSep(val: boolean): this;
    hideClsSep(val?) {
        if (val == null) return this._conf.hideClsSep;

        this._conf.hideClsSep = truthy(val);
        this.toURL();
        return this;
    }

    /** Model name (getter), or set it and sync the URL. */
    model(): string;
    model(val: string): this;
    model(val?) {
        if (val == null) return this._conf.model
        this._conf.model = val
        this.toURL();
        return this
    }

    /** Model kind (getter), or set it and sync the URL. */
    modelKind(): string;
    modelKind(val: string): this;
    modelKind(val?) {
        if (val == null) return this._conf.modelKind
        this._conf.modelKind = val
        this.toURL();
        return this
    }

    /**
     * Return the offset needed for the modelKind in the configuration.
     * Every model kind currently maps to 0; the switch keeps a per-kind hook.
     */
    get offset() {
        switch (this.modelKind()) {
            case tp.ModelKind.Bidirectional: {
                return 0
            }
            case tp.ModelKind.Autoregressive: {
                return 0
            }
            default: {
                return 0
            }
        }
    }

    /** True when the configured model kind is Autoregressive. */
    get showNext() {
        return this.modelKind() === tp.ModelKind.Autoregressive
    }

    /** Histogram label: "Next" for Autoregressive models, otherwise "Matched". */
    get matchHistogramDescription() {
        return this.modelKind() === tp.ModelKind.Autoregressive ? "Next" : "Matched"
    }
}