PolyAI-pheme / modules /speech_tokenizer.py
taras-sereda's picture
minimal set of files to run inference; pheme-small checkpoint
96ee597
raw
history blame
3.1 kB
"""Speech tokenizer class.
Copyright PolyAI Limited.
"""
import logging
import os
import numpy as np
import torch
import torchaudio
from speechtokenizer import SpeechTokenizer as ST
from modules.tokenizer import BaseTokenizer
class SpeechTokenizer(BaseTokenizer):
    """Extract discrete semantic and acoustic codes with SpeechTokenizer.

    Wraps a pretrained ``speechtokenizer.SpeechTokenizer`` model. For every
    encoded audio file two ``.npy`` arrays are written under the destination
    folder:

    * ``semantic/<stem>.npy`` - codes of the first quantizer layer;
    * ``acoustic/<stem>.npy`` - codes of all remaining quantizer layers.
    """

    # Files longer than this many seconds are skipped, since encoding them
    # can exhaust GPU memory.
    MAX_DURATION_SEC = 60

    def __init__(self, config_path: str, ckpt_path: str):
        """Load the model from a checkpoint and switch it to eval mode.

        Args:
            config_path: Path to the SpeechTokenizer config file.
            ckpt_path: Path to the SpeechTokenizer checkpoint file.
        """
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.model = ST.load_from_checkpoint(
            config_path, ckpt_path).to(self.device)
        self.model.eval()

    def encode_file(
            self, folder_path: str, destination_folder: str, filename: str):
        """Encode one audio file and save its codes as ``.npy`` arrays.

        The file is skipped when both output arrays already exist, or when
        the audio is longer than ``MAX_DURATION_SEC`` seconds. Audio whose
        sample rate differs from the model's is resampled first.

        Args:
            folder_path: Folder that contains the audio file.
            destination_folder: Root folder for the ``semantic`` and
                ``acoustic`` output sub-folders.
            filename: Name of the audio file inside ``folder_path``.

        Returns:
            None. Results are written to disk.
        """
        npy_name = os.path.splitext(filename)[0] + ".npy"
        semantic_path = os.path.join(
            destination_folder, "semantic", npy_name)
        acoustic_path = os.path.join(
            destination_folder, "acoustic", npy_name)

        # Nothing to do when both outputs were produced by a previous run.
        if os.path.exists(semantic_path) and os.path.exists(acoustic_path):
            return

        self._create_subfolders(destination_folder=destination_folder)

        file_path = os.path.join(folder_path, filename)
        # Read only the header first so over-long files are rejected
        # without loading the samples.
        wav_info = torchaudio.info(file_path)
        wav_dur_sec = wav_info.num_frames / wav_info.sample_rate
        if wav_dur_sec > self.MAX_DURATION_SEC:
            logging.info(
                f"Skipping {file_path} is too long: {wav_dur_sec:.3f} sec, "
                "can cause CUDA OOM"
            )
            return

        wav, sr = torchaudio.load(file_path)  # wav: (channels, frames)

        # Keep a single channel for multi-channel files, as the upstream
        # SpeechTokenizer usage example does before encoding.
        if wav.shape[0] > 1:
            logging.warning(
                "%s has %d channels. Using the first channel only",
                file_path, wav.shape[0],
            )
            wav = wav[:1, :]

        if sr != self.model.sample_rate:
            logging.warning(
                "Wav sample rate %(wav_sr)s does not match the model "
                "sampling rate %(model_sr)s. Resampling audio",
                {"wav_sr": sr, "model_sr": self.model.sample_rate},
            )
            wav = torchaudio.functional.resample(
                wav, sr, self.model.sample_rate)

        # Add the batch dimension and move the audio to the model's device.
        wav = wav.unsqueeze(0).to(self.device)

        # Extract discrete codes from SpeechTokenizer
        with torch.no_grad():
            codes = self.model.encode(wav)  # codes: (n_q, B, T)

        semantic_tokens = codes[0, 0, :]
        acoustic_tokens = codes[1:, 0, :]

        # Save the encodings as .npy
        np.save(acoustic_path, acoustic_tokens.cpu().numpy())
        np.save(semantic_path, semantic_tokens.cpu().numpy())

    @staticmethod
    def _create_subfolders(destination_folder: str):
        """Create the ``acoustic`` and ``semantic`` output sub-folders.

        Args:
            destination_folder: Root folder in which to create them.
        """
        for subfolder in ("acoustic", "semantic"):
            # exist_ok avoids a failure when the folder appears between an
            # existence check and its creation (e.g. parallel encoding).
            os.makedirs(
                os.path.join(destination_folder, subfolder), exist_ok=True)