Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse | |
import glob | |
import numpy as np | |
# Default embedding width: number of float32 values per row when a raw
# embedding shard is reshaped into a (rows, DIM) matrix.
DIM = 2 ** 10
def compute_dist(source_embs, target_embs, k=5, return_sim_mat=False):
    """Find the ``k`` most similar target sentences for every source sentence.

    Similarity is cosine similarity: every vector is L2-normalised and the
    normalised source matrix is multiplied by the transposed target matrix.

    Args:
        source_embs: mapping of sentence id -> 1-D embedding vector.
        target_embs: mapping of sentence id -> 1-D embedding vector with the
            same length as the source vectors.
        k: number of neighbours kept per source sentence (fewer are returned
            when ``target_embs`` has fewer than ``k`` entries).
        return_sim_mat: if true, return the similarity matrix instead of the
            neighbour map.

    Returns:
        Either the ``(len(source_embs), len(target_embs))`` similarity matrix,
        with rows/columns in dict iteration order, or a dict mapping each
        source id to a list of up to ``k`` target ids, most similar first.
    """
    target_ids = list(target_embs)
    # np.stack needs a real sequence: passing a dict view was deprecated in
    # NumPy 1.16 and newer releases reject it with TypeError.
    source_mat = np.stack(list(source_embs.values()), axis=0)
    # Note: an all-zero vector normalises to NaN similarities.
    normalized_source_mat = source_mat / np.linalg.norm(
        source_mat, axis=1, keepdims=True
    )
    target_mat = np.stack(list(target_embs.values()), axis=0)
    normalized_target_mat = target_mat / np.linalg.norm(
        target_mat, axis=1, keepdims=True
    )
    sim_mat = normalized_source_mat.dot(normalized_target_mat.T)
    if return_sim_mat:
        return sim_mat
    neighbors_map = {}
    for row, sentence_id in zip(sim_mat, source_embs):
        # argsort is ascending; reverse so the most similar target is first.
        top_idx = np.argsort(row)[::-1][:k]
        neighbors_map[sentence_id] = [target_ids[j] for j in top_idx]
    return neighbors_map
def load_embeddings(directory, LANGS, dim=None):
    """Load per-language sentence embeddings and sentence texts.

    For every ``lang`` in ``LANGS`` this reads ``{directory}/{lang}/``. Each
    ``all_avg_pool.{lang}.<shard>`` file is a flat float32 array reshaped into
    rows of ``dim`` values. The matching ``sentences.{lang}.<shard>`` file
    holds one tab-separated ``sentence_id`` / ``sentence`` line per row, in
    row order.

    Args:
        directory: root directory containing one sub-directory per language.
        LANGS: iterable of language codes.
        dim: embedding width; defaults to the module-level ``DIM``.

    Returns:
        ``(sentence_embeddings, sentence_texts)``: two dicts keyed by language,
        each mapping sentence id to its embedding row / sentence text.

    Raises:
        ValueError: if a shard's float count is not a multiple of ``dim``, or
            a sentence line contains no tab separator.
    """
    if dim is None:
        dim = DIM
    sentence_embeddings = {}
    sentence_texts = {}
    for lang in LANGS:
        sentence_embeddings[lang] = {}
        sentence_texts[lang] = {}
        lang_dir = f"{directory}/{lang}"
        # glob order is arbitrary; sort so shards (and thus dict order and
        # downstream tie-breaking) are processed deterministically.
        embedding_files = sorted(glob.glob(f"{lang_dir}/all_avg_pool.{lang}.*"))
        for embed_file in embedding_files:
            shard_id = embed_file.split(".")[-1]
            embeddings = np.fromfile(embed_file, dtype=np.float32)
            if embeddings.shape[0] % dim:
                raise ValueError(
                    f"{embed_file}: {embeddings.shape[0]} floats is not a "
                    f"multiple of dim={dim}"
                )
            embeddings = embeddings.reshape((-1, dim))
            # Explicit UTF-8: sentence files may hold non-ASCII text and the
            # platform default encoding is not guaranteed to be UTF-8.
            with open(
                f"{lang_dir}/sentences.{lang}.{shard_id}", encoding="utf-8"
            ) as sentence_file:
                for idx, line in enumerate(sentence_file):
                    # maxsplit=1 keeps any tabs inside the sentence text.
                    sentence_id, sentence = line.strip().split("\t", 1)
                    sentence_texts[lang][sentence_id] = sentence
                    sentence_embeddings[lang][sentence_id] = embeddings[idx, :]
    return sentence_embeddings, sentence_texts
def compute_accuracy(directory, LANGS):
    """Compute top-1 sentence-retrieval accuracy for every language pair.

    Embeddings are loaded with ``load_embeddings``. For each (source, target)
    pair, a source sentence counts as correct when its nearest target
    neighbour (from ``compute_dist``) has the same sentence id. Accuracy is the
    fraction of source sentences that are correct.

    The matrix is printed as a space-separated table (a header row of language
    codes, then one row per source language) and the same text is written to
    ``{directory}/accuracy``.

    Returns:
        dict mapping source language -> target language -> top-1 accuracy.
    """
    sentence_embeddings, _ = load_embeddings(directory, LANGS)
    top_1_accuracy = {}
    top1_str = " ".join(LANGS) + "\n"
    for source_lang in LANGS:
        top_1_accuracy[source_lang] = {}
        top1_str += f"{source_lang} "
        for target_lang in LANGS:
            neighbors_map = compute_dist(
                sentence_embeddings[source_lang],
                sentence_embeddings[target_lang],
                k=1,
            )
            top1 = sum(
                1
                for sentence_id, neighbors in neighbors_map.items()
                if sentence_id == neighbors[0]
            )
            # Normalise by the number of queried (source) sentences; the
            # target set may differ in size.
            accuracy = top1 / len(neighbors_map)
            top_1_accuracy[source_lang][target_lang] = accuracy
            top1_str += f"{accuracy} "
        top1_str += "\n"
    print(top1_str)
    # Context manager guarantees the report is flushed and the file closed.
    with open(f"{directory}/accuracy", "w", encoding="utf-8") as out_file:
        print(top1_str, file=out_file)
    return top_1_accuracy
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Analyze encoder outputs")
    parser.add_argument(
        "directory",
        help="Directory containing one sub-directory of embeddings and "
        "sentences per language",
    )
    # required=True: without it a missing --langs yields None and .split crashes.
    parser.add_argument(
        "--langs", required=True, help="Comma-separated list of langs, e.g. en,de"
    )
    args = parser.parse_args()
    # Tolerate spaces after commas ("en, de") and ignore empty entries.
    langs = [lang.strip() for lang in args.langs.split(",") if lang.strip()]
    compute_accuracy(args.directory, langs)