Spaces:
Runtime error
Runtime error
import pickle | |
import sys | |
import torch | |
class NGram():
    """Count-based n-gram model backed by a nested count dictionary.

    ``corpus`` maps a context key to a dictionary ``{next_token_id: count}``.
    ``prob`` looks up bigram contexts by the bare token ID and all other
    model types by the key exactly as passed in.
    """

    # Number of context tokens each supported model type conditions on.
    # Types not listed here are not length-checked.
    _CONTEXT_LENGTHS = {
        "bigram": 1,
        "trigram": 2,
        "fourgram": 3,
        "fivegram": 4,
        "sixgram": 5,
        "sevengram": 6,
    }

    def __init__(self, corpus, corpus_counts, type):
        """
        Args:
            corpus (dict): context key -> {next token ID: count}
            corpus_counts: stored as ``self.counts``; not read by this class
            type (str): model type name, e.g. "bigram" ... "sevengram"
        """
        self.corpus = corpus
        self.counts = corpus_counts
        self.type = type

    def prob(self, key, next):
        """
        Relative frequency of ``next`` following the context ``key``.

        Args:
            key (tuple): tuple of token ID's forming prior
            next (int): ID of the candidate next token
        Returns:
            count(key, next) / count(key), or -1 when the context is unseen
            or has no recorded continuations.
        """
        context_len = len(key)
        expected_len = self._CONTEXT_LENGTHS.get(self.type)
        if expected_len is not None:
            assert context_len == expected_len, (
                f"{self.type} expects {expected_len} context token(s), "
                f"got {context_len}"
            )
        if self.type == "bigram":
            # Bigram tables are keyed by the bare token ID, not a 1-tuple.
            key = key[0]

        # Membership test rather than indexing, so a defaultdict-style
        # corpus is never grown by a lookup.
        if key not in self.corpus:
            return -1
        continuations = self.corpus[key]
        total = sum(continuations.values())
        if total == 0:
            # Context present but without usable counts: report it like an
            # unseen context instead of dividing by zero.
            return -1
        return continuations.get(next, 0) / total

    def ntd(self, key, vocab_size=32000):
        """
        Full next-token distribution for the context ``key``.

        Args:
            key (tuple): tuple of token ID's forming prior. For bigram
                models either a 1-tuple or the bare token ID is accepted.
            vocab_size (int): length of the returned tensor
        Returns:
            prob_tensor (torch.Tensor): (vocab_size, ) of full next token
            probabilities, or None when the context is unseen. A context
            with no recorded continuations yields an all-zero tensor.
        """
        if (
            key not in self.corpus
            and self.type == "bigram"
            and isinstance(key, tuple)
            and len(key) == 1
        ):
            # Same unwrapping that prob() applies to bigram contexts.
            key = key[0]
        if key not in self.corpus:
            return None

        continuations = self.corpus[key]
        prob_tensor = torch.zeros(vocab_size)
        total = sum(continuations.values())
        if total > 0:
            token_ids = torch.tensor(list(continuations.keys()), dtype=torch.long)
            # Divide in float64 and round once to the output dtype, which is
            # what per-element assignment of Python floats would do.
            counts = torch.tensor(list(continuations.values()), dtype=torch.float64)
            prob_tensor[token_ids] = (counts / total).to(prob_tensor.dtype)
        return prob_tensor
def make_models(ckpt_path, bigram, trigram, fourgram, fivegram, sixgram, sevengram):
    """
    Loads the requested n-gram count tables and wraps each one in an NGram.

    The returned list is ordered from bigram to sevengram and contains only
    the models whose flag is truthy. Expected file names inside `ckpt_path`:
    b_d_final.pkl, t_d_final.pkl, fo_d_final.pkl, fi_d_final.pkl,
    si_d_final.pkl, se_d_final.pkl.

    Args:
        ckpt_path (str): Location of ngram models
        bigram-sevengram: Which models to load
    Returns:
        List of n-gram models
    """
    # (load flag, model type, checkpoint file name), lowest order first.
    requested = (
        (bigram, "bigram", "b_d_final.pkl"),
        (trigram, "trigram", "t_d_final.pkl"),
        (fourgram, "fourgram", "fo_d_final.pkl"),
        (fivegram, "fivegram", "fi_d_final.pkl"),
        (sixgram, "sixgram", "si_d_final.pkl"),
        (sevengram, "sevengram", "se_d_final.pkl"),
    )

    models = []
    for enabled, model_type, file_name in requested:
        if not enabled:
            continue
        print(f"Making {model_type}...")
        # SECURITY: pickle.load can execute arbitrary code while
        # deserializing; only point ckpt_path at trusted checkpoint files.
        with open(f"{ckpt_path}/{file_name}", "rb") as f:
            table = pickle.load(f)
        models.append(NGram(table, None, model_type))
        # Shallow size of the top-level container only; nested count
        # dictionaries are not included in this number.
        print(sys.getsizeof(table))
    return models