/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// https://github.com/tensorflow/tfjs-models/blob/master/universal-sentence-encoder/src/tokenizer/trie.ts
// A single node in the vocabulary trie. `score` and `index` are only set on
// nodes that terminate a vocabulary entry (`end === true`).
class TrieNode {
  constructor(key) {
    this.key = key;
    this.parent = null;
    this.children = {};
    this.end = false;
  }

  // Walks back up to the root and returns [characters, score, vocab index]
  // for the vocabulary entry ending at this node.
  getWord() {
    const output = [];
    let node = this;
    while (node !== null) {
      if (node.key !== null) {
        output.unshift(node.key);
      }
      node = node.parent;
    }
    return [output, this.score, this.index];
  }
}
// Character trie over the vocabulary, used for greedy longest-match lookup.
class Trie {
  constructor() {
    this.root = new TrieNode(null);
  }

  // Inserts `word` character by character, storing `score` and the
  // vocabulary `index` on the terminal node.
  insert(word, score, index) {
    let node = this.root;
    const symbols = [];
    for (const symbol of word) {
      symbols.push(symbol);
    }
    for (let i = 0; i < symbols.length; i++) {
      if (!node.children[symbols[i]]) {
        node.children[symbols[i]] = new TrieNode(symbols[i]);
        node.children[symbols[i]].parent = node;
      }
      node = node.children[symbols[i]];
      if (i === symbols.length - 1) {
        node.end = true;
        node.score = score;
        node.index = index;
      }
    }
  }

  // Follows `ss` character by character and returns the node reached, or
  // undefined if the prefix is not in the trie.
  find(ss) {
    let node = this.root;
    let iter = 0;
    while (iter < ss.length && node != null) {
      node = node.children[ss[iter]];
      iter++;
    }
    return node;
  }
}
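
// Illustration of the Trie API above (with a hypothetical vocabulary index):
// after
//   trie.insert('\u2581hello', 1, 1234);
// trie.find('\u2581hello') returns a node with end === true and index === 1234,
// while trie.find('\u2581hel') returns an intermediate node with end === false.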
// Public entry point: loads the vocabulary and returns a ready-to-use tokenizer.
const bert = {
  loadTokenizer: async () => {
    const tokenizer = new BertTokenizer();
    await tokenizer.load();
    return tokenizer;
  }
};
class BertTokenizer {
  constructor() {
    this.separator = '\u2581';
    this.UNK_INDEX = 100;
  }

  async load() {
    this.vocab = await this.loadVocab();
    this.trie = new Trie();
    // Actual tokens start at index 999; earlier entries are special or unused.
    for (let i = 999; i < this.vocab.length; i++) {
      const word = this.vocab[i];
      this.trie.insert(word, 1, i);
    }
    // Reverse mapping from token string to vocabulary id.
    this.token2Id = {};
    this.vocab.forEach((d, i) => {
      this.token2Id[d] = i;
    });
    // Maps ids back to text, turning the '\u2581' word marker into a space.
    this.decode = a => a.map(d => this.vocab[d].replace('\u2581', ' ')).join('');
    // Adds [CLS] (101) and [SEP] (102) around the tokenized input.
    this.tokenizeCLS = str => [101, ...this.tokenize(str), 102];
  }

  async loadVocab() {
    // Cache the fetched vocabulary on `window` so it is only downloaded once.
    if (!window.bertProcessedVocab) {
      window.bertProcessedVocab = await (await fetch('data/processed_vocab.json')).json();
    }
    return window.bertProcessedVocab;
  }

  // Splits on spaces, lowercases and NFKC-normalizes each word, and prefixes
  // it with the '\u2581' word-start marker; [CLS] and [SEP] pass through unchanged.
  processInput(text) {
    const words = text.split(' ');
    return words.map(word => {
      if (word !== '[CLS]' && word !== '[SEP]') {
        return this.separator + word.toLowerCase().normalize('NFKC');
      }
      return word;
    });
  }
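
  // Splits each marked-up word into the longest vocabulary pieces found in
  // the trie and returns their vocabulary ids; a word with no match at some
  // position collapses to a single UNK_INDEX. Illustration only, with a
  // hypothetical vocabulary containing just '\u2581un' and 'able':
  // tokenize('unable') would emit the ids of ['\u2581un', 'able'].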
  tokenize(text) {
    // Greedy longest-match-first WordPiece tokenization. Source:
    // https://github.com/google-research/bert/blob/88a817c37f788702a363ff935fd173b6dc6ac0d6/tokenization.py#L311
    let outputTokens = [];
    const words = this.processInput(text);
    for (let i = 0; i < words.length; i++) {
      const chars = [];
      for (const symbol of words[i]) {
        chars.push(symbol);
      }
      let isUnknown = false;
      let start = 0;
      const subTokens = [];
      const charsLength = chars.length;
      while (start < charsLength) {
        // Find the longest substring starting at `start` that is in the vocabulary.
        let end = charsLength;
        let currIndex;
        while (start < end) {
          const substr = chars.slice(start, end).join('');
          const match = this.trie.find(substr);
          if (match != null && match.end) {
            currIndex = match.getWord()[2];
            break;
          }
          end = end - 1;
        }
        // No prefix of the remaining characters is in the vocabulary:
        // the whole word is mapped to the unknown token.
        if (currIndex == null) {
          isUnknown = true;
          break;
        }
        subTokens.push(currIndex);
        start = end;
      }
      if (isUnknown) {
        outputTokens.push(this.UNK_INDEX);
      } else {
        outputTokens = outputTokens.concat(subTokens);
      }
    }
    return outputTokens;
  }
}
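
// Example usage (sketch): assumes this script runs in a browser page that also
// serves data/processed_vocab.json, as loadVocab() expects.
//
//   bert.loadTokenizer().then(tokenizer => {
//     const ids = tokenizer.tokenizeCLS('hello world');
//     // ids is [101, ...wordpiece ids..., 102]
//     console.log(tokenizer.decode(ids.slice(1, -1)));
//   });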