/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// https://github.com/tensorflow/tfjs-models/blob/master/universal-sentence-encoder/src/tokenizer/trie.ts | |
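// A character-level trie over the WordPiece vocabulary. Each vocabulary entry
// is inserted one character at a time, so the tokenizer below can do greedy
// longest-prefix matching against it.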
class TrieNode {
  constructor(key) {
    this.key = key;
    this.parent = null;
    this.children = {};
    this.end = false; // true if this node terminates a vocabulary entry
  }

  // Walks parent pointers back to the root to reconstruct the entry,
  // returning [characters, score, index].
  getWord() {
    const output = [];
    let node = this;
    while (node !== null) {
      if (node.key !== null) {
        output.unshift(node.key);
      }
      node = node.parent;
    }
    return [output, this.score, this.index];
  }
}
class Trie {
  constructor() {
    this.root = new TrieNode(null);
  }

  // Inserts a vocabulary entry, storing its score and vocab index on the
  // terminal node. Spreading the string iterates full code points, so
  // multi-byte characters are handled correctly.
  insert(word, score, index) {
    let node = this.root;
    const symbols = [...word];
    for (let i = 0; i < symbols.length; i++) {
      if (!node.children[symbols[i]]) {
        node.children[symbols[i]] = new TrieNode(symbols[i]);
        node.children[symbols[i]].parent = node;
      }
      node = node.children[symbols[i]];
      if (i === symbols.length - 1) {
        node.end = true;
        node.score = score;
        node.index = index;
      }
    }
  }

  // Follows ss through the trie; returns the node reached after consuming
  // all of ss, or undefined if the path leaves the trie.
  find(ss) {
    let node = this.root;
    let iter = 0;
    while (iter < ss.length && node != null) {
      node = node.children[ss[iter]];
      iter++;
    }
    return node;
  }
}
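// Illustrative usage of the trie on its own (the scores and indices here are
// made up; real values come from the vocab file loaded below):
//
//   const trie = new Trie();
//   trie.insert('\u2581he', 1, 3000);
//   trie.insert('\u2581hello', 1, 7592);
//   const match = trie.find('\u2581hello');
//   // match.end === true, match.getWord()[2] === 7592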
const bert = {
  loadTokenizer: async () => {
    const tokenizer = new BertTokenizer();
    await tokenizer.load();
    return tokenizer;
  }
};
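// Usage: const tokenizer = await bert.loadTokenizer();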
class BertTokenizer {
  constructor() {
    this.separator = '\u2581';
    this.UNK_INDEX = 100; // [UNK] in the standard BERT vocab
  }

  async load() {
    this.vocab = await this.loadVocab();
    this.trie = new Trie();
    // Actual tokens start at 999; earlier entries are special tokens and
    // [unused] slots.
    for (let i = 999; i < this.vocab.length; i++) {
      const word = this.vocab[i];
      this.trie.insert(word, 1, i);
    }

    this.token2Id = {};
    this.vocab.forEach((d, i) => {
      this.token2Id[d] = i;
    });

    // Maps ids back to strings, turning each token's leading separator into
    // a space.
    this.decode = a => a.map(d => this.vocab[d].replace(this.separator, ' ')).join('');

    // Adds [CLS] (101) and [SEP] (102) around the token ids.
    this.tokenizeCLS = str => [101, ...this.tokenize(str), 102];
  }
  async loadVocab() {
    // Cache the parsed vocab on window so repeated loads skip the fetch.
    if (!window.bertProcessedVocab) {
      window.bertProcessedVocab = await (await fetch('data/processed_vocab.json')).json();
    }
    return window.bertProcessedVocab;
  }
  processInput(text) {
    // Lowercase and NFKC-normalize each word and prefix it with the
    // SentencePiece-style separator; special tokens pass through unchanged.
    const words = text.split(' ');
    return words.map(word => {
      if (word !== '[CLS]' && word !== '[SEP]') {
        return this.separator + word.toLowerCase().normalize('NFKC');
      }
      return word;
    });
  }
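  // For example, processInput('Hello WORLD') returns
  // ['\u2581hello', '\u2581world'].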
  tokenize(text) {
    // Greedy longest-match-first WordPiece, following the reference
    // implementation:
    // https://github.com/google-research/bert/blob/88a817c37f788702a363ff935fd173b6dc6ac0d6/tokenization.py#L311
    let outputTokens = [];
    const words = this.processInput(text);
    for (let i = 0; i < words.length; i++) {
      const chars = [...words[i]];
      let isUnknown = false;
      let start = 0;
      const subTokens = [];
      const charsLength = chars.length;
      while (start < charsLength) {
        // Find the longest substring starting at `start` that is in the
        // vocabulary, shrinking the window from the right.
        let end = charsLength;
        let currIndex;
        while (start < end) {
          const substr = chars.slice(start, end).join('');
          const match = this.trie.find(substr);
          if (match != null && match.end) {
            currIndex = match.getWord()[2];
            break;
          }
          end = end - 1;
        }
        // No prefix of the remainder is in the vocab: the whole word maps
        // to [UNK].
        if (currIndex == null) {
          isUnknown = true;
          break;
        }
        subTokens.push(currIndex);
        start = end;
      }
      if (isUnknown) {
        outputTokens.push(this.UNK_INDEX);
      } else {
        outputTokens = outputTokens.concat(subTokens);
      }
    }
    return outputTokens;
  }
} |
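// Minimal end-to-end sketch (illustrative; assumes data/processed_vocab.json
// is served alongside the page, and the exact ids depend on that vocab):
//
//   const tokenizer = await bert.loadTokenizer();
//   const ids = tokenizer.tokenizeCLS('hello world');
//   // e.g. [101, ..., 102], with [CLS] and [SEP] wrapping the word-piece ids
//   console.log(tokenizer.decode(tokenizer.tokenize('hello world')));
//   // ' hello world' (each word is rejoined with its leading space)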