from model import encoder_text import torch, clip, random import numpy as np device = torch.device("cpu") from words import words ########## SENTENCE PART ####################################################### voyelles = ["a","e","i","o","u"] links = list(words.keys())[1:] def link_text(part,nextWord): ### Check if we need to write "... a", "... an", "..." if (len(part["link"]) > 0) and (part["link"][-1] == "a"): voyelleStart = (nextWord[0] in voyelles) plural = (nextWord[-1] == "s" and nextWord[-2] != "s") or (nextWord in ["nothing","hair","vampire teeth","something"]) else: voyelleStart, plural = False, False return (part["link"][:-2] if plural else part["link"] + ("n" if voyelleStart else "")) def part_text(part): l = link_text(part,part["word"]) return l + (" " if len(l)>0 else "") + part["word"] def compute_embeddings(part,var_dict,prefix,batch_size=64): target = part["word"] possibleWords = list(set(words[part["link"]]) - set([target]+var_dict["found_words"])) if len(possibleWords) > (batch_size-1): possibleWords = np.random.choice(list(possibleWords),batch_size-1,replace=False).tolist() possibleWords.append(target) ### Compute all classes & embeddings for current sentence part part["classes"] = [prefix + link_text(part,w) + (" " if len(link_text(part,w))>0 else "") + w for w in possibleWords] with torch.no_grad(): embeddings = encoder_text(clip.tokenize(part["classes"]).to(device)) embeddings /= embeddings.norm(dim=-1, keepdim=True) part["embeddings"] = embeddings.tolist() ########## SENTENCE ############################################################ def iniSentence(var_dict,input="",first_game=False): var_dict["found_words"] = [] var_dict["parts"] = [] var_dict["step"] = 0 prefix = "" N = (2 if var_dict["difficulty"] == 1 else 1) if first_game: link = "a drawing of a" part = {"link":link,"word":"cat","classes":[],"embeddings":[]} var_dict["parts"].append(part) compute_embeddings(part, var_dict, prefix) prefix += part_text(part) + " " link = "with a" part = {"link":link,"word":"face","classes":[],"embeddings":[]} var_dict["parts"].append(part) compute_embeddings(part, var_dict, prefix) prefix += part_text(part) + " " else: ##### Generating Random Sentence link = "a drawing of a" part = {"link":link,"word":np.random.choice(words[link]),"classes":[],"embeddings":[]} var_dict["parts"].append(part) compute_embeddings(part, var_dict, prefix) prefix += part_text(part) + " " for i in range(N-1): link = np.random.choice(links) part = {"link":link,"word":np.random.choice(words[link][1:]),"classes":[],"embeddings":[]} var_dict["parts"].append(part) compute_embeddings(part, var_dict, prefix) prefix += part_text(part) + " " var_dict["target_sentence"] = prefix[:-1] # Target sentence is prefix without the last space setState(var_dict) return var_dict["target_sentence"] def prevState(var_dict): if len(var_dict["prev_steps"]) > 0: var_dict["step"] = var_dict["prev_steps"].pop(-1) else: var_dict["step"] = 0 var_dict["revertedState"] = True setState(var_dict) def setState(var_dict): var_dict["found_words"] = var_dict["found_words"][:var_dict["step"]] var_dict["guessed_sentence"] = "" for i in range(var_dict["step"]): var_dict["guessed_sentence"] += part_text(var_dict["parts"][i]) + " " def updateState(var_dict, preds): if not var_dict["revertedState"]: var_dict["prev_steps"].append(var_dict["step"]) else: var_dict["revertedState"] = False ### Check if the current part has been guessed part = var_dict["parts"][var_dict["step"]] idx_of_nothing = -1 if ("nothing" in preds[0]): idx_of_nothing = 0 elif ("nothing" in preds[1]): idx_of_nothing = 1 elif ("nothing" in preds[2]): idx_of_nothing = 2 idx_of_guess = -1 if (part["classes"][-1] == preds[0]): idx_of_guess = 0 elif (part["classes"][-1] == preds[1]): idx_of_guess = 1 elif (part["classes"][-1] == preds[2]): idx_of_guess = 2 if not var_dict["win"] and (idx_of_guess > idx_of_nothing): var_dict["step"] += 1 var_dict["found_words"].append(part["word"]) var_dict["win"] = var_dict["step"] == len(var_dict["parts"]) setState(var_dict) if var_dict["win"]: return 1 else: return 0 elif not var_dict["win"]: return -1 else: return 1