# Page header from the hosting Space: status "Running on Zero";
# file size 5,427 bytes; revision 8133f69.
import math
import os
import json
import pickle
import random
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
# Directory whose entries get_data() lists when no restricted dataset is given.
EMPTY_DATA_PATH = "tangram_pngs/"
# Directory of pickled similarity tables (one file per tangram key) that
# generate_complete_game() loads into its clip_model mapping.
CLIP_FOLDER = "clip_similarities"
def generate_complete_game():
    """Build one game context from the tangram corpus and CLIP similarities.

    Loads the corpus with ``get_data()`` and one pickled similarity table per
    file in ``CLIP_FOLDER``, then delegates to ``get_pragmatic_context``.

    Returns:
        The dict produced by ``get_pragmatic_context`` (keys
        ``"speaker_context"``, ``"listener_context"`` and ``"targets"``).
    """
    # First get corpus and clip model
    curr_corpus = get_data()
    clip_model = _load_clip_model(CLIP_FOLDER)
    # Next get the pragmatic context
    return get_pragmatic_context(curr_corpus, clip_model)


def _load_clip_model(folder):
    """Return {tangram key: unpickled similarity table} for the files in *folder*."""
    clip_model = {}
    for filename in os.listdir(folder):
        # Skip hidden files such as macOS ".DS_Store" (get_data filters the
        # same junk from the image folder); unpickling one would fail.
        if filename.startswith("."):
            continue
        # NOTE(review): pickle.load can execute arbitrary code from the file;
        # only point CLIP_FOLDER at trusted, locally generated data.
        with open(os.path.join(folder, filename), 'rb') as f:
            curr_similarities = pickle.load(f)
        # Key = first two '-'-separated fields of the filename. Presumably this
        # matches the ``tangram[:-4]`` lookups in the sampling helpers; verify
        # against the actual files in CLIP_FOLDER.
        tangram_name = '-'.join(filename.split('-')[:2])
        clip_model[tangram_name] = curr_similarities
    return clip_model
def get_pragmatic_context(curr_corpus, clip_model, context_size=10,
                          num_blocks=3, num_targets=3):
    """Sample a context made of similarity blocks, plus target tangrams.

    The context consists of ``num_blocks`` similarity blocks whose sizes sum to
    ``context_size`` (split round-robin by ``evenly_spread_values``). Each
    block starts from a base tangram chosen by ``sample_similarity_block_base``
    and is filled by ``sample_similarity_block``. The tangrams used by a block
    are removed from the corpus before the next block is sampled.

    Args:
        curr_corpus: list of tangram filenames to sample from (not mutated).
        clip_model: nested mapping ``clip_model[a][b]`` -> similarity, keyed by
            filename minus its last four characters (as the helpers index it).
        context_size: total number of tangrams in the context (default 10).
        num_blocks: number of similarity blocks (default 3).
        num_targets: number of targets drawn at random from the context
            (default 3).

    Returns:
        dict with ``"speaker_context"`` and ``"listener_context"`` (two
        independent random orderings of the context) and ``"targets"``.
    """
    overall_context = []
    # Split the context size across the blocks, e.g. 10 over 3 -> [4, 3, 3].
    block_sizes = evenly_spread_values(context_size, num_blocks)
    for block_size in block_sizes:
        # Sample the base tangram, then the block of tangrams around it.
        base_tangram = sample_similarity_block_base(curr_corpus, clip_model,
                                                    overall_context)
        similarity_block = sample_similarity_block(curr_corpus, base_tangram,
                                                   block_size, clip_model)  # TODO
        overall_context.extend(similarity_block)
        # Drop used tangrams so later blocks cannot reuse them; the set keeps
        # the filter linear in the corpus size.
        used = set(overall_context)
        curr_corpus = [tangram for tangram in curr_corpus if tangram not in used]
    # Sample the targets at random
    targets = random.sample(overall_context, num_targets)
    # Speaker and listener get the same tangrams in independent random orders.
    # Shuffling a copy applies exactly the permutation an index shuffle would.
    speaker_images = list(overall_context)
    random.shuffle(speaker_images)
    listener_images = list(overall_context)
    random.shuffle(listener_images)
    return {
        "speaker_context": speaker_images,
        "listener_context": listener_images,
        "targets": targets,
    }
def evenly_spread_values(block_size, num_similarity_blocks):
    """Deal ``block_size`` items round-robin into ``num_similarity_blocks`` bins.

    Returns one count per bin; earlier bins absorb the remainder, so
    ``evenly_spread_values(10, 3) == [4, 3, 3]``.
    """
    counts = [0] * num_similarity_blocks
    for item_index in range(block_size):
        counts[item_index % num_similarity_blocks] += 1
    return counts
def sample_similarity_block_base(curr_corpus, clip_model, overall_context):
    """Pick a random base tangram for the next similarity block.

    The draw is uniform over the tangrams of ``curr_corpus`` that
    ``get_candidate_base_tangrams`` returns for ``overall_context``.
    ``random.sample`` raises ValueError when no candidate is left.
    """
    candidates = get_candidate_base_tangrams(
        curr_corpus, clip_model, overall_context
    )
    (chosen,) = random.sample(candidates, 1)
    return chosen
def get_candidate_base_tangrams(curr_corpus, clip_model, overall_context):
    """Return, in corpus order, the tangrams that ``valid_base_tangram`` accepts."""
    return [
        candidate
        for candidate in curr_corpus
        if valid_base_tangram(overall_context, candidate, clip_model)
    ]
def valid_base_tangram(overall_context, tangram, clip_model, threshold=1):
    """Return True unless *tangram* is too similar to a tangram already in context.

    The candidate is rejected as soon as its similarity to any tangram in
    ``overall_context`` exceeds ``threshold``. Filenames become
    ``clip_model`` keys by dropping their last four characters (the
    extension, e.g. ``".png"``).

    Args:
        overall_context: tangram filenames already placed in the context.
        tangram: candidate tangram filename.
        clip_model: nested mapping ``clip_model[a][b]`` -> similarity.
        threshold: similarity above which the candidate is rejected. The
            default of 1 preserves the original hard-coded value.
            NOTE(review): if the stored similarities are cosine scores in
            [-1, 1], the default never rejects anything; confirm the
            intended cut-off.

    Raises:
        KeyError: if a key is missing from ``clip_model``.
    """
    candidate_key = tangram[:-4]
    for context_tangram in overall_context:
        if clip_model[context_tangram[:-4]][candidate_key] > threshold:
            return False
    return True
def sample_similarity_block(curr_corpus, base_tangram, similarity_block_size,
                            clip_model, temperature=0.055):
    """Build a block of ``similarity_block_size`` tangrams around *base_tangram*.

    The block is ``[base_tangram]`` followed by ``similarity_block_size - 1``
    tangrams from ``curr_corpus``. They are drawn without replacement, with
    probabilities given by a softmax over their similarity to the base
    (``get_similarity_distribution``). A lower ``temperature`` concentrates
    the draw on the most similar tangrams.

    Args:
        curr_corpus: tangram filenames still available.
        base_tangram: filename of the block's base tangram.
        similarity_block_size: total block size, including the base.
        clip_model: nested mapping; ``clip_model[base_key]`` maps tangram keys
            (filename minus its last four characters) to similarity scores.
        temperature: softmax temperature; 0.055 is the original hard-coded value.

    Returns:
        list of filenames, base tangram first.

    Raises:
        KeyError: if the base tangram's key is missing from ``clip_model``.
        ValueError: if fewer than ``similarity_block_size - 1`` candidates remain.
    """
    base_key = base_tangram[:-4]
    # Reuse the base tangram's own extension rather than assuming ".png", so
    # corpora with other extensions (get_data can yield ".svg" names) match.
    extension = base_tangram[-4:]
    available = set(curr_corpus)  # O(1) membership instead of a list scan
    # Get the most similar tangrams to the base tangram
    ranked = sorted(clip_model[base_key].items(), reverse=True,
                    key=lambda item: item[1])
    # Keep tangrams still in the corpus, but never the base itself. A
    # self-similarity entry would typically score highest and could put the
    # base into its own block twice.
    ranked = [(name + extension, score) for name, score in ranked
              if name != base_key and name + extension in available]
    # Separate out the tangrams and the scores
    sorted_tangrams = [name for name, _ in ranked]
    sorted_scores = [score for _, score in ranked]
    k = similarity_block_size - 1
    if k > len(sorted_tangrams):
        raise ValueError(
            f"need {k} tangrams similar to {base_tangram!r} but only "
            f"{len(sorted_tangrams)} remain in the corpus"
        )
    distribution = get_similarity_distribution(sorted_scores, temperature)
    sampled_indices = sample_without_replacement(distribution, k)
    return [base_tangram] + [sorted_tangrams[i] for i in sampled_indices]
def get_similarity_distribution(scores, temperature):
    """Return ``softmax(scores / temperature)`` as a 1-D float tensor.

    Smaller temperatures sharpen the distribution toward the highest scores.
    """
    scaled_scores = [value / temperature for value in scores]
    return F.softmax(torch.Tensor(scaled_scores), dim=0)
def sample_without_replacement(distribution, K):
    """Draw ``K`` distinct indices from *distribution* without replacement.

    Indices are drawn one at a time with ``torch.multinomial``. After each
    draw, the chosen index's weight is zeroed and the rest are renormalised,
    so later draws are proportional to the remaining weights. The input
    tensor is not modified (a clone is sampled from).

    Args:
        distribution: 1-D tensor of non-negative weights.
        K: number of indices to draw; ``K <= 0`` returns ``[]``.

    Returns:
        list of ``K`` Python ints, in draw order.

    Raises:
        ValueError: if ``K`` exceeds the number of positive-weight entries.
            Otherwise the renormalised weights would become NaN and
            ``torch.multinomial`` would fail with an opaque error.
    """
    weights = torch.clone(distribution)
    available = int((weights > 0).sum())
    if K > available:
        raise ValueError(
            f"cannot draw {K} distinct indices: only {available} have "
            f"positive weight"
        )
    samples = []
    for _ in range(K):
        current_sample = torch.multinomial(weights, num_samples=1).item()
        samples.append(current_sample)
        weights[current_sample] = 0
        weights = weights / torch.sum(weights)
    return samples
def get_data(restricted_dataset=""):
    """Return the tangram corpus as a shuffled list of filenames.

    Args:
        restricted_dataset: path to a pickle holding a list of tangram names
            without extension; ``".svg"`` is appended to each. When empty or
            None (the default), every entry of ``EMPTY_DATA_PATH`` is used.

    Returns:
        list of filenames in random order. Entries containing ``.DS_Store``
        and the known duplicate tangrams (page6-51, page6-66, page4-170,
        whatever their extension) are removed.
    """
    # Get the list of all paths
    if not restricted_dataset:
        paths = os.listdir(EMPTY_DATA_PATH)
    else:
        # NOTE(review): pickle.load can execute arbitrary code; only pass
        # trusted, locally generated files here.
        with open(restricted_dataset, 'rb') as f:
            paths = pickle.load(f)
        paths = [path + ".svg" for path in paths]
    paths = [path for path in paths if ".DS_Store" not in path]
    random.shuffle(paths)
    # Remove duplicates. Match on the stem so the ".svg" names built above are
    # caught too, and drop every occurrence rather than only the first.
    duplicate_stems = {'page6-51', 'page6-66', 'page4-170'}
    return [path for path in paths
            if os.path.splitext(path)[0] not in duplicate_stems]
|