|
"""Generate tangram reference-game contexts by sampling CLIP-similarity blocks."""
import os
import pickle
import random

import torch
import torch.nn.functional as F
|
EMPTY_DATA_PATH = "tangram_pngs/"
CLIP_FOLDER = "clip_similarities"
|
|
|
def generate_complete_game():
    """Sample one complete game: a shared context plus per-role image orderings."""
    curr_corpus = get_data()

    # Load the precomputed CLIP similarity tables, one pickle per tangram.
    clip_files = os.listdir(CLIP_FOLDER)
    clip_model = {}
    for filename in clip_files:
        with open(os.path.join(CLIP_FOLDER, filename), 'rb') as f:
            curr_similarities = pickle.load(f)

        # Tangram names look like "page6-51": keep the first two dash-separated fields.
        tangram_name = '-'.join(filename.split('-')[:2])
        clip_model[tangram_name] = curr_similarities

    context_dict = get_pragmatic_context(curr_corpus, clip_model)
    return context_dict
|
|
|
def get_pragmatic_context(curr_corpus, clip_model):
    """Build a 10-image context from 3 CLIP-similarity blocks and pick 3 targets."""
    overall_context = []
    base_tangrams = []
    individual_blocks = []

    # Spread the 10-image context across 3 similarity blocks (sizes [4, 3, 3]).
    block_sizes = evenly_spread_values(10, 3)

    for i in range(3):
        # Pick a base tangram that is not too similar to anything already chosen.
        base_tangram = sample_similarity_block_base(curr_corpus, clip_model, overall_context)
        base_tangrams.append(base_tangram)

        # Fill the rest of the block from the base tangram's nearest CLIP neighbors.
        similarity_block = sample_similarity_block(curr_corpus, base_tangram, block_sizes[i], clip_model)
        individual_blocks.append(similarity_block)
        overall_context.extend(similarity_block)

        # Remove the chosen tangrams so later blocks cannot reuse them.
        curr_corpus = [tangram for tangram in curr_corpus if tangram not in overall_context]

    targets = random.sample(overall_context, 3)

    # Show the same context to speaker and listener in independent random orders.
    speaker_order = list(range(len(overall_context)))
    random.shuffle(speaker_order)
    speaker_images = [overall_context[i] for i in speaker_order]

    listener_order = list(range(len(overall_context)))
    random.shuffle(listener_order)
    listener_images = [overall_context[i] for i in listener_order]

    context_dict = {
        "speaker_context": speaker_images,
        "listener_context": listener_images,
        "targets": targets,
    }
    return context_dict
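
# The returned dict has this shape (filenames here are purely illustrative):
#   {"speaker_context": ["page3-12.png", ...],   # 10 images, speaker's order
#    "listener_context": [...],                  # same 10 images, reshuffled
#    "targets": [...]}                           # 3 images drawn from the context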
|
|
|
def evenly_spread_values(block_size, num_similarity_blocks):
    """Split block_size items across num_similarity_blocks as evenly as possible."""
    sim_block_sizes = [0 for _ in range(num_similarity_blocks)]
    for i in range(block_size):
        idx = i % num_similarity_blocks
        sim_block_sizes[idx] += 1
    return sim_block_sizes
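
# For example, evenly_spread_values(10, 3) == [4, 3, 3]: the leftover image
# goes to the first block.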
|
|
|
def sample_similarity_block_base(curr_corpus, clip_model, overall_context):
    """Pick a random base tangram that is dissimilar from the current context."""
    candidate_base_tangrams = get_candidate_base_tangrams(curr_corpus, clip_model,
                                                          overall_context)
    base_tangram = random.choice(candidate_base_tangrams)
    return base_tangram
|
|
|
def get_candidate_base_tangrams(curr_corpus, clip_model, overall_context):
    """Return every corpus tangram that could serve as a new block's base."""
    candidate_base_tangrams = []
    for tangram in curr_corpus:
        if valid_base_tangram(overall_context, tangram, clip_model):
            candidate_base_tangrams.append(tangram)
    return candidate_base_tangrams
|
|
|
def valid_base_tangram(overall_context, tangram, clip_model):
    """Reject a tangram whose CLIP similarity to any context image exceeds the cutoff."""
    for context_tangram in overall_context:
        # Filenames end in ".png"; strip the extension to index the similarity tables.
        if clip_model[context_tangram[:-4]][tangram[:-4]] > 1:
            return False
    return True
|
|
|
def sample_similarity_block(curr_corpus, base_tangram, similarity_block_size,
                            clip_model):
    """Sample a block of similarity_block_size tangrams built around base_tangram."""
    base_similarities = clip_model[base_tangram[:-4]]
    sorted_similarities = sorted(base_similarities.items(), reverse=True, key=lambda x: x[1])
    sorted_similarities = [sim for sim in sorted_similarities if sim[0] + ".png" in curr_corpus]

    sorted_tangrams = [sim[0] + ".png" for sim in sorted_similarities]
    sorted_scores = [sim[1] for sim in sorted_similarities]
    # The base tangram itself fills one slot, so sample the remaining k neighbors.
    k = similarity_block_size - 1

    # A low temperature concentrates probability mass on the most similar tangrams.
    distribution = get_similarity_distribution(sorted_scores, 0.055)
    sampled_indices = sample_without_replacement(distribution, k)
    similarity_block = [base_tangram] + [sorted_tangrams[i] for i in sampled_indices]
    return similarity_block
|
|
|
def get_similarity_distribution(scores, temperature):
    """Convert similarity scores into a temperature-scaled softmax distribution."""
    logits = torch.tensor([score / temperature for score in scores])
    probs = F.softmax(logits, dim=0)
    return probs
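
# With the temperature of 0.055 used above, a similarity gap of only 0.1
# already scales the sampling odds by about e^(0.1 / 0.055) ≈ 6.2.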
|
|
|
def sample_without_replacement(distribution, K):
    """Draw K distinct indices from distribution, renormalizing after each draw."""
    new_distribution = torch.clone(distribution)

    samples = []
    for _ in range(K):
        current_sample = torch.multinomial(new_distribution, num_samples=1).item()
        samples.append(current_sample)

        # Zero out the chosen index so it cannot be drawn again.
        new_distribution[current_sample] = 0
        new_distribution = new_distribution / torch.sum(new_distribution)

    return samples
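
# Note: torch.multinomial(distribution, K, replacement=False) draws K distinct
# indices in a single call; the explicit loop above keeps the per-draw
# renormalization visible.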
|
|
|
def get_data(restricted_dataset=""):
    """List the tangram image filenames, optionally restricted to a pickled subset."""
    if restricted_dataset == "":
        paths = os.listdir(EMPTY_DATA_PATH)
    else:
        with open(restricted_dataset, 'rb') as f:
            paths = pickle.load(f)
        # The pickle stores bare tangram names; restore the ".png" extension used
        # throughout the rest of this module.
        paths = [path + ".png" for path in paths]
    paths = [path for path in paths if ".DS_Store" not in path]
    random.shuffle(paths)

    # Drop known duplicate tangrams from the corpus.
    for duplicate in ['page6-51.png', 'page6-66.png', 'page4-170.png']:
        if duplicate in paths:
            paths.remove(duplicate)

    return paths
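

# A minimal usage sketch (assumes "tangram_pngs/" and "clip_similarities/" sit
# next to this script and the similarity pickles are already computed):
if __name__ == "__main__":
    game = generate_complete_game()
    print("speaker context:", game["speaker_context"])
    print("listener context:", game["listener_context"])
    print("targets:", game["targets"])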
|
|
|
|
|
|