"""Generate the next line of text by nearest-neighbour lookup on model output.

At import time this module loads a pickled dataset (``set.pckl``) and a
Keras model (``net.h5``) using paths relative to the current working
directory.
"""

import pickle
import random

import numpy as np
from tensorflow.keras.models import load_model
from tqdm import tqdm

import embed_set
import embedding
import net

# Number of nearest dataset entries that ``generate`` samples from.
top_p = 1

# Number of trailing input lines fed to the model as context.
_CONTEXT_LINES = 3


class SetLine:
    """A dataset entry: a line of text (``name``) and its embedding (``inp``)."""

    def __init__(self, name, inp):
        self.name = name
        # NOTE(review): the ``inp`` argument is ignored; the vector is always
        # recomputed from ``name`` via ``embedding.getvec``. Confirm intended.
        self.inp = embedding.getvec(name)


# SECURITY: pickle.load can execute arbitrary code -- only load a trusted file.
# Presumably a sequence of SetLine-like objects (the functions below read
# ``.inp`` and ``.name`` from each entry); confirm against the writer script.
with open("set.pckl", "rb") as f:
    dset = pickle.load(f)

model = load_model("net.h5")


def top_closest_vectors(input_vector, top_p=1):
    """Return the ``top_p`` entries of ``dset`` nearest to ``input_vector``.

    Args:
        input_vector: Query vector, compared against each entry's ``inp``.
        top_p: How many of the closest entries to return.

    Returns:
        A list of ``(distance, index)`` tuples in ascending order of Euclidean
        distance, where ``index`` is the entry's position in ``dset``. Entries
        at equal distance keep their dataset order (the sort is stable).
    """
    distances = [
        (np.linalg.norm(line.inp - input_vector), index)
        for index, line in enumerate(dset)
    ]
    distances.sort(key=lambda pair: pair[0])
    return distances[:top_p]


def generate(text):
    """Predict the line that follows ``text`` and return a matching dataset line.

    Each newline-separated line of ``text`` is embedded with
    ``embedding.getvec``. The sequence is left-padded with zero vectors of
    length ``net.vec_size`` and truncated to the last ``_CONTEXT_LINES``
    entries. It is then passed to the model as a batch of one. The model's
    prediction is matched against ``dset``, and one of the ``top_p`` closest
    entries (module-level setting) is chosen at random.

    Args:
        text: Input text; lines are separated by ``"\\n"``.

    Returns:
        The ``name`` of the chosen dataset entry.
    """
    # Zero padding guarantees at least _CONTEXT_LINES vectors even for short
    # input, so the model always receives a fixed-length window.
    padding = [np.zeros(net.vec_size)] * _CONTEXT_LINES
    context = padding + [embedding.getvec(line) for line in text.split("\n")]
    batch = np.array([context[-_CONTEXT_LINES:]])
    predicted = model.predict(batch)[0]
    # Pass the module-level ``top_p`` explicitly; relying on the callee's
    # default would make that setting have no effect.
    candidates = top_closest_vectors(predicted, top_p=top_p)
    return dset[random.choice(candidates)[1]].name