# Preprocessing/multilinguality/generate_zero_shot_lang_embs.py
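"""Approximate language embeddings for unseen (zero-shot) languages by averaging
the embeddings of their closest related languages, then inject them into a
trained ToucanTTS checkpoint."""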
import argparse
import json
import os
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from Utility.storage_config import MODELS_DIR


def approximate_and_inject_language_embeddings(model_path, df, iso_lookup, min_n_langs=5, max_n_langs=25, threshold_percentile=50):
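    """For each target language in df, average the embeddings of its closest
    languages (per the distance columns), write the result into the model's
    language embedding table, and save the modified checkpoint."""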
# load pretrained language_embeddings
model = torch.load(model_path, map_location="cpu")
lang_embs = model["model"]["encoder.language_embedding.weight"]
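    # lang_embs is the embedding matrix with shape [num_languages, embedding_dim]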
features_per_closest_lang = 2
    # for "combined", df has up to 5 feature columns per closest language (combined distance plus any individual distances) + 1 target-lang column
if "combined_dist_0" in df.columns:
if "map_dist_0" in df.columns:
features_per_closest_lang += 1
if "asp_dist_0" in df.columns:
features_per_closest_lang += 1
if "tree_dist_0" in df.columns:
features_per_closest_lang += 1
        n_closest = len(df.columns) // features_per_closest_lang  # floor division absorbs the single extra target_lang column
distance_type = "combined"
    # otherwise, df has 2 feature columns per closest language + 1 target-lang column
else:
n_closest = len(df.columns) // features_per_closest_lang
if "map_dist_0" in df.columns:
distance_type = "map"
elif "tree_dist_0" in df.columns:
distance_type = "tree"
elif "asp_dist_0" in df.columns:
distance_type = "asp"
elif "learned_dist_0" in df.columns:
distance_type = "learned"
else:
distance_type = "random"
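    # Illustrative column layout for distance_type="combined" with all individual distances
    # present (column order is an example, not guaranteed):
    #   target_lang | closest_lang_0 | combined_dist_0 | map_dist_0 | asp_dist_0 | tree_dist_0 | closest_lang_1 | ...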
# get relevant columns
closest_lang_columns = [f"closest_lang_{i}" for i in range(n_closest)]
closest_dist_columns = [f"{distance_type}_dist_{i}" for i in range(n_closest)]
closest_lang_columns = closest_lang_columns[:max_n_langs]
closest_dist_columns = closest_dist_columns[:max_n_langs]
    assert df[closest_dist_columns[-1]].isna().sum() == 0, "distance column of the furthest considered language must not contain NaNs"
    # cutoff threshold: the given percentile (across all samples) of the distance to the furthest considered language
threshold = np.percentile(df[closest_dist_columns[-1]], threshold_percentile)
print(f"threshold: {threshold:.4f}")
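    # e.g., with threshold_percentile=50, the threshold is the median (over all rows)
    # of the distance to the furthest considered closest language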
for row in tqdm(df.itertuples(), total=df.shape[0], desc="Approximating language embeddings"):
        avg_emb = torch.zeros(lang_embs.shape[1])  # embedding dim inferred from the model instead of hardcoding it
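        # always keep the min_n_langs closest languages; beyond that, keep only those with distance below the threshold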
dists = [getattr(row, d) for i, d in enumerate(closest_dist_columns) if i < min_n_langs or getattr(row, d) < threshold]
langs = [getattr(row, l) for l in closest_lang_columns[:len(dists)]]
for lang in langs:
lang_emb = lang_embs[iso_lookup[-1][str(lang)]]
avg_emb += lang_emb
        avg_emb /= len(langs)  # average over the selected languages
lang_embs[iso_lookup[-1][str(row.target_lang)]] = avg_emb
# inject language embeddings into Toucan model and save
model["model"]["encoder.language_embedding.weight"] = lang_embs
    modified_model_path = os.path.splitext(model_path)[0] + "_zeroshot_lang_embs.pt"  # splitext is robust to dots elsewhere in the path
torch.save(model, modified_model_path)
print(f"Replaced unsupervised language embeddings with zero-shot approximations.\nSaved modified model to {modified_model_path}")
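

# Example invocation (paths are illustrative):
#   python generate_zero_shot_lang_embs.py --dataset_path distance_datasets/dataset_learned_top30.csv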
if __name__ == "__main__":
    default_model_path = os.path.join(MODELS_DIR, "ToucanTTS_Meta", "best.pt")  # MODELS_DIR must be an absolute path; a relative path will fail here
default_csv_path = "distance_datasets/dataset_learned_top30.csv"
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default=default_model_path, help="path of the model for which the language embeddings should be modified")
parser.add_argument("--dataset_path", type=str, default=default_csv_path, help="path to distance dataset CSV")
parser.add_argument("--min_n_langs", type=int, default=5, help="minimum amount of languages used for averaging")
parser.add_argument("--max_n_langs", type=int, default=25, help="maximum amount of languages used for averaging")
    parser.add_argument("--threshold_percentile", type=int, default=50,
                        help="percentile of the furthest considered language's distance, used as cutoff threshold "
                             "(beyond the min_n_langs closest languages, those with distance >= the threshold are excluded from averaging)")
args = parser.parse_args()
ISO_LOOKUP_PATH = "iso_lookup.json"
with open(ISO_LOOKUP_PATH, "r") as f:
iso_lookup = json.load(f) # iso_lookup[-1] = iso2id mapping
# load language distance dataset
distance_df = pd.read_csv(args.dataset_path, sep="|")
approximate_and_inject_language_embeddings(model_path=args.model_path,
df=distance_df,
iso_lookup=iso_lookup,
min_n_langs=args.min_n_langs,
max_n_langs=args.max_n_langs,
threshold_percentile=args.threshold_percentile)