# LLM-fastAPI/preprocess/vocabulary.py
# coding=utf-8
"""
Vocabulary helper class
"""
import re

import numpy as np


class Vocabulary:
    """Stores tokens and the bidirectional token <-> index mapping."""

    def __init__(self, tokens=None, starting_id=0):
        self._tokens = {}
        self._current_id = starting_id

        if tokens:
            for token, idx in tokens.items():
                self._add(token, idx)
                self._current_id = max(self._current_id, idx + 1)

    def __getitem__(self, token_or_id):
        return self._tokens[token_or_id]

    def add(self, token):
        """Adds a token and returns its id, or None if it was already present."""
        if not isinstance(token, str):
            raise TypeError("Token is not a string")
        if token in self:
            # raise ValueError("Token already present in the vocabulary")
            print(f'=== Token "{token}" already present in the vocabulary')
            return
        self._add(token, self._current_id)
        self._current_id += 1
        return self._current_id - 1

    def update(self, tokens):
        """Adds many tokens."""
        return [self.add(token) for token in tokens]

    def __delitem__(self, token_or_id):
        # Remove both directions of the mapping (token -> id and id -> token).
        other_val = self._tokens[token_or_id]
        del self._tokens[other_val]
        del self._tokens[token_or_id]

    def __contains__(self, token_or_id):
        return token_or_id in self._tokens

    def __eq__(self, other_vocabulary):
        return self._tokens == other_vocabulary._tokens

    def __len__(self):
        # Each entry is stored twice (token -> id and id -> token).
        return len(self._tokens) // 2

    def encode(self, tokens):
        """Encodes a list of tokens into a float32 vector of token indices."""
        ohe_vect = np.zeros(len(tokens), dtype=np.float32)
        for i, token in enumerate(tokens):
            try:
                ohe_vect[i] = self._tokens[token]
            except KeyError:
                # Fall back to the id of the "default_key" token, which is
                # assumed to have been registered in the vocabulary.
                ohe_vect[i] = self._tokens["default_key"]
        return ohe_vect

    def decode(self, ohe_vect):
        """Decodes a vector of token indices back into a list of tokens."""
        tokens = []
        for ohv in ohe_vect:
            try:
                tokens.append(self[ohv])
            except KeyError:
                tokens.append("default_key")
        return tokens

    def _add(self, token, idx):
        if idx not in self._tokens:
            self._tokens[token] = idx
            self._tokens[idx] = token
        else:
            raise ValueError("IDX already present in vocabulary")

    def tokens(self):
        """Returns the tokens from the vocabulary."""
        return [t for t in self._tokens if isinstance(t, str)]

    def word2idx(self):
        """Returns the token -> index half of the mapping."""
        return {k: self._tokens[k] for k in self._tokens if isinstance(k, str)}
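

# Minimal usage sketch (illustrative, not part of the original file): the
# vocabulary keeps token -> id and id -> token in one dict, so lookups work
# in both directions and __len__ halves the dict size.
#
#     vocab = Vocabulary()
#     vocab.update(["*", "^", "$", "C"])  # assigns ids 0, 1, 2, 3
#     vocab.encode(["C", "^"])            # -> array([3., 1.], dtype=float32)
#     vocab.decode([3, 1])                # -> ["C", "^"]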


class SMILESTokenizer:
    """Deals with the tokenization and untokenization of SMILES."""

    REGEXPS = {
        "brackets": re.compile(r"(\[[^\]]*\])"),
        "2_ring_nums": re.compile(r"(%\d{2})"),
        "brcl": re.compile(r"(Br|Cl)"),
    }
    REGEXP_ORDER = ["brackets", "2_ring_nums", "brcl"]

    def tokenize(self, data, with_begin_and_end=True):
        """Tokenizes a SMILES string."""

        def split_by(data, regexps):
            if not regexps:
                return list(data)
            regexp = self.REGEXPS[regexps[0]]
            splitted = regexp.split(data)
            tokens = []
            for i, split in enumerate(splitted):
                if i % 2 == 0:
                    # Even chunks lie between matches; keep splitting them
                    # with the remaining regexes.
                    tokens += split_by(split, regexps[1:])
                else:
                    # Odd chunks are the captured groups; keep them whole.
                    tokens.append(split)
            return tokens

        tokens = split_by(data, self.REGEXP_ORDER)
        if with_begin_and_end:
            tokens = ["^"] + tokens + ["$"]
        return tokens

    def untokenize(self, tokens):
        """Untokenizes a SMILES string."""
        smi = ""
        for token in tokens:
            if token == "$":
                break
            if token != "^":
                smi += token
        return smi
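

# Tokenization sketch (illustrative SMILES, not from the original project):
# bracket atoms, two-digit ring numbers (%NN), and the two-letter elements
# Br/Cl stay whole; everything else is split into single characters.
#
#     SMILESTokenizer().tokenize("C1=CC=CC=C1Br", with_begin_and_end=False)
#     # -> ['C', '1', '=', 'C', 'C', '=', 'C', 'C', '=', 'C', '1', 'Br']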


def create_vocabulary(smiles_list, tokenizer, property_condition=None):
    """Creates a vocabulary for the SMILES syntax."""
    tokens = set()
    for smi in smiles_list:
        tokens.update(tokenizer.tokenize(smi, with_begin_and_end=False))

    vocabulary = Vocabulary()
    vocabulary.update(["*", "^", "$"] + sorted(tokens))  # pad=0, start=1, end=2
    if property_condition is not None:
        vocabulary.update(property_condition)
    # For randomized SMILES: ensure "8" is in the vocabulary even if it never
    # appears in the input strings.
    if "8" not in vocabulary.tokens():
        vocabulary.update(["8"])
    return vocabulary
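

# End-to-end sketch (hedged: the SMILES strings below are illustrative
# examples for this demo, not data from the original project).
if __name__ == "__main__":
    tokenizer = SMILESTokenizer()
    vocab = create_vocabulary(["CCO", "c1ccccc1", "CC(=O)Oc1ccccc1Br"], tokenizer)
    tokens = tokenizer.tokenize("CCO")    # ['^', 'C', 'C', 'O', '$']
    encoded = vocab.encode(tokens)        # float32 vector of token ids
    decoded = vocab.decode(encoded)
    print(tokens)
    print(encoded)
    print(tokenizer.untokenize(decoded))  # prints "CCO"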