from model import encoder_text   # text encoder defined in model.py (takes CLIP-tokenised text)
import torch, clip
import numpy as np
device = torch.device("cpu")

from words import words

########## SENTENCE PART #######################################################
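### A sentence "part" is a dict {"link": linking phrase, "word": target word,
### "classes": candidate sentences, "embeddings": their CLIP text embeddings}.
### "voyelles" (vowels) is used to pick "a" vs "an" in link_text below.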
voyelles = ["a","e","i","o","u"]
links    = list(words.keys())[1:]

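### Adjust the linking phrase to the word that follows it: "a" becomes "an" before a
### vowel, and the article is dropped for plurals and for words that take no article.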
def link_text(part,nextWord):
    ### Check if we need to write "... a", "... an", "..."
    if (len(part["link"]) > 0) and (part["link"][-1] == "a"):
        voyelleStart = (nextWord[0] in voyelles)
        plural       = (nextWord[-1] == "s" and nextWord[-2] != "s") or (nextWord in ["nothing","hair","vampire teeth","something"])   # plural nouns and words that take no article
    else:
        voyelleStart, plural = False, False
    return (part["link"][:-2] if plural else part["link"] + ("n" if voyelleStart else ""))

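### Full text of a part: the (adjusted) linking phrase followed by the word itself.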
def part_text(part):
    l = link_text(part,part["word"])
    return l + (" " if len(l)>0 else "") + part["word"]

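### Build the candidate classes for a part (up to batch_size-1 distractors plus the
### target, always last) and pre-compute their L2-normalised CLIP text embeddings.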
def compute_embeddings(part,var_dict,prefix,batch_size=64):
    target        = part["word"]
    ### Distractors: words from the same list, minus the target and the words already found
    possibleWords = list(set(words[part["link"]]) - set([target]+var_dict["found_words"]))
    if len(possibleWords) > (batch_size-1): possibleWords = np.random.choice(possibleWords,batch_size-1,replace=False).tolist()
    possibleWords.append(target)   # the target class is always the last one
    ### Compute all classes & embeddings for current sentence part
    part["classes"] = [prefix + part_text({"link": part["link"], "word": w}) for w in possibleWords]
    with torch.no_grad():
        embeddings         = encoder_text(clip.tokenize(part["classes"]).to(device))
        embeddings        /= embeddings.norm(dim=-1, keepdim=True)
        part["embeddings"] = embeddings.tolist()

########## SENTENCE ############################################################
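### var_dict carries the whole game state. This module fills "parts", "step",
### "found_words", "target_sentence" and "guessed_sentence"; the caller is expected
### to initialise "difficulty", "prev_steps", "revertedState" and "win".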
def iniSentence(var_dict,input="",first_game=False):   # note: the "input" argument is currently unused
    var_dict["found_words"] = []
    var_dict["parts"]       = []
    var_dict["step"]        = 0
    prefix                  = ""
    N                       = (2 if var_dict["difficulty"] == 1 else 1)   # number of parts in a randomly generated sentence

    if first_game:
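        ##### Fixed first-game sentence: "a drawing of a cat with a face"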
        link = "a drawing of a"
        part = {"link":link,"word":"cat","classes":[],"embeddings":[]}
        var_dict["parts"].append(part)
        compute_embeddings(part, var_dict, prefix)
        prefix += part_text(part) + " "

        link = "with a"
        part = {"link":link,"word":"face","classes":[],"embeddings":[]}
        var_dict["parts"].append(part)
        compute_embeddings(part, var_dict, prefix)
        prefix += part_text(part) + " "
    else:
        ##### Generating Random Sentence
        link = "a drawing of a"
        part = {"link":link,"word":np.random.choice(words[link]),"classes":[],"embeddings":[]}
        var_dict["parts"].append(part)
        compute_embeddings(part, var_dict, prefix)
        prefix += part_text(part) + " "

        for i in range(N-1):
            link  = np.random.choice(links)
            part  = {"link":link,"word":np.random.choice(words[link][1:]),"classes":[],"embeddings":[]}
            var_dict["parts"].append(part)
            compute_embeddings(part, var_dict, prefix)
            prefix += part_text(part) + " "

    var_dict["target_sentence"] = prefix[:-1] # Target sentence is prefix without the last space
    setState(var_dict)
    return var_dict["target_sentence"]

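### Undo one step: go back to the previously recorded step (or to 0) and rebuild the state.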
def prevState(var_dict):
    if len(var_dict["prev_steps"]) > 0: var_dict["step"] = var_dict["prev_steps"].pop(-1)
    else:                               var_dict["step"] = 0
    var_dict["revertedState"] = True
    setState(var_dict)

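### Rebuild "found_words" and "guessed_sentence" so they match the current step.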
def setState(var_dict):
    var_dict["found_words"] = var_dict["found_words"][:var_dict["step"]]
    var_dict["guessed_sentence"] = ""
    for i in range(var_dict["step"]):
        var_dict["guessed_sentence"] += part_text(var_dict["parts"][i]) + " "

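### One guessing round. "preds" holds three predicted class sentences; the return
### value is 1 when the whole sentence is complete, 0 when the current part was
### just guessed, and -1 otherwise.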
def updateState(var_dict, preds):
    if not var_dict["revertedState"]:  var_dict["prev_steps"].append(var_dict["step"])
    else:                              var_dict["revertedState"] = False

    ### Check if the current part has been guessed
    part = var_dict["parts"][var_dict["step"]]

    ### Index of the first "nothing" class and of the target class in the predictions (-1 if absent)
    idx_of_nothing = next((i for i in range(3) if "nothing" in preds[i]), -1)
    idx_of_guess   = next((i for i in range(3) if part["classes"][-1] == preds[i]), -1)


    ### The guess is accepted when the target class appears in preds at a higher index than any "nothing" class
    if not var_dict["win"] and (idx_of_guess > idx_of_nothing):
        var_dict["step"] += 1
        var_dict["found_words"].append(part["word"])
        var_dict["win"] = var_dict["step"] == len(var_dict["parts"])
        setState(var_dict)
        return 1 if var_dict["win"] else 0   # 1: whole sentence found, 0: current part found
    return 1 if var_dict["win"] else -1      # -1: nothing new guessed this round
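
########## USAGE SKETCH ########################################################
### Minimal sketch of one round, assuming the caller owns var_dict and provides the
### keys this module only reads ("difficulty", "prev_steps", "revertedState", "win").
### "preds" below reuses the part's own classes as a stand-in for real classifier
### output, which presumably comes from a CLIP image/text comparison elsewhere in
### the project; it also assumes the part has at least three candidate classes.
if __name__ == "__main__":
    var_dict = {"difficulty": 1, "prev_steps": [], "revertedState": False, "win": False}
    target   = iniSentence(var_dict, first_game=True)   # "a drawing of a cat with a face"
    part     = var_dict["parts"][var_dict["step"]]
    preds    = part["classes"][-3:]                      # pretend the target class was predicted (it is classes[-1])
    print(target, updateState(var_dict, preds))          # prints the target sentence and 0 (first part guessed)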