Spaces:
Running
on
Zero
Running
on
Zero
import math | |
import os | |
import json | |
import pickle | |
import random | |
import torch | |
import torch.nn.functional as F | |
import torch.nn as nn | |
import numpy as np | |
# Folder whose directory listing forms the default tangram corpus (see get_data).
EMPTY_DATA_PATH: str = "tangram_pngs/"
# Folder of pickled similarity tables; the first two '-'-separated fields of
# each file name give the tangram the table belongs to (see generate_complete_game).
CLIP_FOLDER: str = "clip_similarities"
def generate_complete_game():
    """Build the speaker/listener contexts and targets for one game.

    Loads the tangram corpus with ``get_data()`` and the per-tangram
    similarity tables stored in ``CLIP_FOLDER``, then delegates sampling to
    ``get_pragmatic_context``.

    Returns:
        The dictionary produced by ``get_pragmatic_context`` (keys
        ``"speaker_context"``, ``"listener_context"`` and ``"targets"``).
    """
    # First get corpus and clip model
    curr_corpus = get_data()
    clip_model = {}
    for filename in os.listdir(CLIP_FOLDER):
        # Skip hidden files such as macOS ".DS_Store": get_data() filters the
        # same artifact from the image folder, and it is not a pickle.
        if filename.startswith("."):
            continue
        # Get values. NOTE: pickle.load can execute arbitrary code; only point
        # CLIP_FOLDER at trusted, locally generated files.
        with open(os.path.join(CLIP_FOLDER, filename), 'rb') as f:
            curr_similarities = pickle.load(f)
        # Get keys: the first two '-'-separated fields of the file name. Other
        # code looks these up via tangram[:-4] (file name minus a 4-character
        # extension).
        tangram_name = '-'.join(filename.split('-')[:2])
        clip_model[tangram_name] = curr_similarities
    # Next get the pragmatic context
    return get_pragmatic_context(curr_corpus, clip_model)
def get_pragmatic_context(curr_corpus, clip_model, *, context_size=10,
                          num_blocks=3, num_targets=3):
    """Sample a context built from similarity blocks, plus targets.

    The context consists of ``num_blocks`` similarity blocks whose sizes sum
    to ``context_size`` (split by ``evenly_spread_values``). Each block is a
    base tangram followed by tangrams sampled by similarity to it. After
    each block, every tangram already placed is removed from the pool so
    blocks do not overlap.

    Args:
        curr_corpus: list of tangram file names available for sampling.
            The caller's list is not mutated.
        clip_model: mapping from tangram name (file name without its
            4-character extension) to a mapping of similarity scores keyed
            the same way.
        context_size: total number of tangrams in the context (default 10).
        num_blocks: number of similarity blocks (default 3).
        num_targets: number of targets drawn from the context (default 3).

    Returns:
        dict with ``"speaker_context"`` and ``"listener_context"`` (two
        independent random orderings of the same context) and ``"targets"``.
    """
    overall_context = []
    for block_size in evenly_spread_values(context_size, num_blocks):
        # Sample the base tangram, then the block built around it
        base_tangram = sample_similarity_block_base(curr_corpus, clip_model, overall_context)
        similarity_block = sample_similarity_block(curr_corpus, base_tangram, block_size,
                                                   clip_model)
        overall_context.extend(similarity_block)
        # Filter out everything already placed so later blocks cannot reuse
        # it; a set keeps the membership test O(1) per tangram.
        used = set(overall_context)
        curr_corpus = [tangram for tangram in curr_corpus if tangram not in used]

    # Sample the targets at random
    targets = random.sample(overall_context, num_targets)
    # Speaker and listener see the same tangrams in independent orders.
    speaker_images = _shuffled_copy(overall_context)
    listener_images = _shuffled_copy(overall_context)
    return {
        "speaker_context": speaker_images,
        "listener_context": listener_images,
        "targets": targets,
    }


def _shuffled_copy(items):
    """Return a new list holding ``items`` in random order (input untouched).

    random.shuffle's swaps depend only on the length and the RNG state, so
    this yields the same ordering as shuffling an index list and gathering.
    """
    shuffled = list(items)
    random.shuffle(shuffled)
    return shuffled
def evenly_spread_values(block_size, num_similarity_blocks):
    """Deal ``block_size`` units round-robin into ``num_similarity_blocks`` bins.

    Returns the list of per-bin counts. Earlier bins absorb the remainder, so
    for example ``evenly_spread_values(10, 3)`` gives ``[4, 3, 3]``.
    """
    counts = [0] * num_similarity_blocks
    for unit in range(block_size):
        counts[unit % num_similarity_blocks] += 1
    return counts
def sample_similarity_block_base(curr_corpus, clip_model, overall_context):
    """Pick one valid base tangram uniformly at random.

    Candidates come from ``get_candidate_base_tangrams``. ``random.sample``
    raises ``ValueError`` when there are no candidates.
    """
    candidates = get_candidate_base_tangrams(curr_corpus, clip_model, overall_context)
    (chosen,) = random.sample(candidates, 1)
    return chosen
def get_candidate_base_tangrams(curr_corpus, clip_model, overall_context):
    """Return, in corpus order, the tangrams accepted by ``valid_base_tangram``."""
    return [
        candidate
        for candidate in curr_corpus
        if valid_base_tangram(overall_context, candidate, clip_model)
    ]
def valid_base_tangram(overall_context, tangram, clip_model, max_similarity=1):
    """Return whether ``tangram`` is dissimilar enough to start a new block.

    The candidate is rejected as soon as its similarity to any tangram
    already in ``overall_context`` exceeds ``max_similarity``.

    Args:
        overall_context: tangram file names already placed in the context.
        tangram: candidate tangram file name.
        clip_model: nested mapping where ``clip_model[a][b]`` is a similarity
            score. Both keys are file names with their last 4 characters
            (the extension) removed.
        max_similarity: rejection threshold. The default of 1 is the value
            that was previously hard-coded. NOTE(review): if the stored
            scores are cosine similarities (never above 1), this default
            never rejects anything; confirm the intended scale.

    Returns:
        True if no context tangram is more similar than ``max_similarity``.

    Raises:
        KeyError: if a required entry is missing from ``clip_model``.
    """
    for context_tangram in overall_context:
        if clip_model[context_tangram[:-4]][tangram[:-4]] > max_similarity:
            return False
    return True
def sample_similarity_block(curr_corpus, base_tangram, similarity_block_size,
                            clip_model, temperature=0.055):
    """Build a block made of ``base_tangram`` plus tangrams similar to it.

    Candidates are the tangrams in ``clip_model`` for the base that are still
    in ``curr_corpus``. ``similarity_block_size - 1`` of them are drawn
    without replacement, with probabilities ``softmax(score / temperature)``.
    Lower temperatures therefore concentrate the draw on the most similar
    tangrams.

    Args:
        curr_corpus: available tangram file names. Only ``"<name>.png"``
            entries can match a candidate.
        base_tangram: file name of the block's base; always returned first.
        similarity_block_size: total number of tangrams in the block.
        clip_model: nested similarity mapping keyed by file name minus its
            4-character extension.
        temperature: softmax temperature. The default of 0.055 is the value
            that was previously hard-coded.

    Returns:
        list of file names, starting with ``base_tangram``.
    """
    base_name = base_tangram[:-4]
    # Set for O(1) membership instead of scanning the corpus list per candidate.
    available = set(curr_corpus)
    # Rank candidates from most to least similar to the base, keeping only
    # tangrams still in the corpus. The base is still in curr_corpus at this
    # point, so exclude it explicitly in case its table has a self-entry;
    # otherwise the block could contain the base twice.
    ranked = sorted(clip_model[base_name].items(), reverse=True, key=lambda item: item[1])
    ranked = [(name, score) for name, score in ranked
              if name != base_name and name + ".png" in available]
    sorted_tangrams = [name + ".png" for name, _ in ranked]
    sorted_scores = [score for _, score in ranked]

    distribution = get_similarity_distribution(sorted_scores, temperature)
    sampled_indices = sample_without_replacement(distribution, similarity_block_size - 1)
    return [base_tangram] + [sorted_tangrams[i] for i in sampled_indices]
def get_similarity_distribution(scores, temperature):
    """Turn similarity ``scores`` into a probability vector with a softmax.

    Each score is divided by ``temperature`` first, so smaller temperatures
    give a sharper distribution. The result is a 1-D float tensor aligned
    index-for-index with ``scores``.
    """
    scaled = [value / temperature for value in scores]
    return F.softmax(torch.Tensor(scaled), dim=0)
def sample_without_replacement(distribution, K):
    """Draw ``K`` distinct indices from the categorical ``distribution``.

    Indices are drawn one at a time with ``torch.multinomial``. After each
    draw, the chosen index's weight is set to zero and the remaining weights
    are renormalized. The input tensor is cloned and left unchanged.

    Returns:
        list of ``K`` Python ints, in draw order.
    """
    weights = distribution.clone()
    drawn = []
    for _ in range(K):
        index = torch.multinomial(weights, num_samples=1).item()
        drawn.append(index)
        weights[index] = 0
        weights = weights / weights.sum()
    return drawn
def get_data(restricted_dataset=""):
    """Return a shuffled list of tangram image file names.

    Args:
        restricted_dataset: optional path to a pickle holding an iterable of
            tangram names without extension. When empty or None (the
            default), every entry of ``EMPTY_DATA_PATH`` is used instead.

    Returns:
        list of file names in random order. ``.DS_Store`` entries and the
        three tangrams listed below as duplicates are removed.
    """
    # Get the list of all paths
    if not restricted_dataset:
        paths = os.listdir(EMPTY_DATA_PATH)
    else:
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(restricted_dataset, 'rb') as f:
            paths = pickle.load(f)
        # Use ".png" like the rest of this module: the image folder, the
        # duplicate list below, and the name + ".png" lookup in
        # sample_similarity_block. The previous ".svg" suffix meant restricted
        # corpora never matched any similarity candidate or duplicate entry.
        paths = [path + ".png" for path in paths]
    paths = [path for path in paths if ".DS_Store" not in path]
    random.shuffle(paths)
    # Remove duplicates
    for duplicate in ['page6-51.png', 'page6-66.png', 'page4-170.png']:
        if duplicate in paths:
            paths.remove(duplicate)
    return paths