# Source: uploaded by xuan3986 ("Upload 111 files", commit 03022ee, verified)
import logging
import torch
import pickle
import numpy as np
from funcineforge.utils.hinter import hint_once
from funcineforge.datasets import FunCineForgeDS
from funcineforge.models import FunCineForgeSpecAug
class FunCineForgeDataset(torch.utils.data.Dataset):
    """
    Dataset for Mixed LM of FunCineForge.

    Each item is built from a ``FunCineForgeDS`` entry and combines:
      * tokenized text (optionally prefixed with an emotion "clue"),
      * speech codec tokens loaded from a ``.npy`` file (optionally
        SpecAug-masked for robustness),
      * time/speaker ids,
      * per-frame face embeddings loaded from a pickle file,
    into one mixed token sequence with per-position flags and LM labels.
    The companion ``collator`` pads the variable-length fields into batch
    tensors and enforces a token budget per batch.
    """
    def __init__(
        self,
        path,
        index_ds: str = None,
        frontend=None,
        tokenizer=None,
        face_encoder=None,
        int_pad_value: int = -1,
        float_pad_value: float = 0.0,
        **kwargs,
    ):
        """
        Args:
            path: Dataset location, forwarded to ``FunCineForgeDS``.
            index_ds: Accepted but not used in this class (interface compat).
            frontend: Accepted but not used in this class (interface compat).
            tokenizer: Text tokenizer; must expose ``encode(str)`` returning
                a sequence of int token ids (used in ``__getitem__``).
            face_encoder: Stored on the instance; not used in this file.
            int_pad_value: Padding value for int32/int64 tensors in ``collator``.
            float_pad_value: Padding value for float tensors in ``collator``.
            **kwargs: Extra configuration read via ``kwargs.get`` below
                (batch_size, batch_type, retry, codebook_size, specaug,
                specaug_conf, use_emotion_clue, ...); also forwarded
                wholesale to ``FunCineForgeDS``.
        """
        super().__init__()
        self.index_ds = FunCineForgeDS(path, **kwargs)
        self.tokenizer = tokenizer
        self.face_encoder = face_encoder
        self.int_pad_value = int_pad_value
        self.float_pad_value = float_pad_value
        self.batch_size = kwargs.get("batch_size")
        self.batch_type = kwargs.get("batch_type")
        # Max attempts used by both __getitem__ and collator loops.
        self.retry = kwargs.get("retry", 100)
        # self.kwargs = kwargs
        self.max_token_length = kwargs.get("max_token_length", 1500)
        self.batch_size_scale_ratio_max = kwargs.get("batch_size_scale_ratio_max", 1.5)
        # Hard cap on padded batch size (batch * seq_len) in collator.
        self.batch_size_token_max = kwargs.get("batch_size_token_max", 2500)
        self.multiturn_num_max = kwargs.get("multiturn_num_max", 1)
        # Dimensionality of a single face embedding vector.
        self.face_size = kwargs.get("face_size", 512)
        self.codebook_size = kwargs.get("codebook_size", 6561)
        # Special token ids are laid out directly after the codec codebook:
        # sos = codebook_size, eos = codebook_size + 1,
        # turn_of_speech = codebook_size + 2.
        self.sos = kwargs.get("sos", self.codebook_size)
        self.eos = kwargs.get("eos", self.codebook_size + 1)
        self.turn_of_speech = kwargs.get("turn_of_speech", self.codebook_size + 2)
        # Label value ignored by the LM loss (CrossEntropy ignore_index).
        self.ignore_id = kwargs.get("ignore_id", -100)
        specaug = kwargs.get("specaug", None)
        specaug_conf = kwargs.get("specaug_conf", {})
        # NOTE(review): "specaug" acts only as an on/off switch — any
        # non-None value is discarded and replaced by a FunCineForgeSpecAug
        # built from specaug_conf. Confirm this is intended.
        if specaug is not None:
            specaug = FunCineForgeSpecAug(**specaug_conf)
        self.specaug = specaug
        self.set_invalid_xvec_zeros = kwargs.get("set_invalid_xvec_zeros", False)
        self.use_emotion_clue = kwargs.get("use_emotion_clue", False)
        logging.info(f"use_emotion_clue: {self.use_emotion_clue}")
    def get_source_len(self, index):
        """Return the source length of item ``index`` as reported by the index dataset."""
        item = self.index_ds[index]
        source_len = self.index_ds.get_source_len(item)
        return source_len
    def get_target_len(self, index):
        """Return the target length of item ``index`` as reported by the index dataset."""
        item = self.index_ds[index]
        return self.index_ds.get_target_len(item)
    def __len__(self) -> int:
        """Number of items in the underlying index dataset."""
        return len(self.index_ds)
    def mixup_text_codec(self, text: torch.Tensor, aug_codec: torch.Tensor, timespk_ids: torch.Tensor, type_id: int):
        """
        Assemble one mixed LM sequence and its masks/labels.

        Sequence layout (by position):
            [sos] text... type_id timespk_ids... turn_of_speech codec... [eos]

        Args:
            text: 1-D int tensor of text token ids.
            aug_codec: 1-D int tensor of (possibly augmented) codec tokens.
            timespk_ids: 1-D int tensor of time/speaker ids.
            type_id: Single token id inserted between text and timespk_ids.

        Returns:
            (input_ids, labels, text_flag, codec_flag, timespk_flag):
            int64 input/label tensors plus three float32 {0,1} masks over
            the sequence positions.
        """
        text_len = text.shape[0]
        timespk_len = timespk_ids.shape[0]
        sequence = [self.sos, *text.tolist(), type_id, *timespk_ids.tolist(), self.turn_of_speech, *aug_codec.tolist(), self.eos]
        # sequence = [self.sos, *text.tolist(), type_id, self.turn_of_speech, *aug_codec.tolist(), self.eos]
        input_ids = torch.tensor(sequence, dtype=torch.int64)
        # text positions: indices 1 .. text_len (sos excluded).
        text_flag = torch.zeros(len(sequence), dtype=torch.float32)
        text_flag[1:text_len+1] = 1
        # timespk positions: indices text_len+1 .. text_len+1+timespk_len —
        # note this mask also covers the type_id position at text_len+1.
        timespk_flag = torch.zeros(len(sequence), dtype=torch.float32)
        timespk_flag[text_len+1:text_len+2+timespk_len] = 1
        # timespk_flag[text_len+1:text_len+2] = 1
        # remaining positions: sos, turn_of_speech, codec tokens, eos.
        codec_flag = 1 - (text_flag + timespk_flag)
        labels = torch.tensor(sequence, dtype=torch.int64)
        # Mask the prefix up to and including turn_of_speech
        # (text_len + timespk_len + 3 positions), so the loss is computed
        # only on the codec tokens and eos.
        labels[:text_len+timespk_len+3] = self.ignore_id
        # labels[:text_len+3] = self.ignore_id
        return input_ids, labels, text_flag, codec_flag, timespk_flag
    def __getitem__(self, index):
        """
        Build the training sample for ``index``.

        Returns a dict with the mixed token sequence, its flags/labels, the
        per-frame face-embedding tensor and the codec length.

        NOTE(review): this loop retries with a random index on later
        iterations, but nothing in the body is caught — any exception
        propagates immediately and the loop otherwise breaks after the
        first pass, so the retry path is currently unreachable. Presumably
        a try/except around the body was intended; confirm.
        """
        output = None
        for idx in range(self.retry):
            if idx == 0:
                index_cur = index
            else:
                # fall back to a random item on retries
                index_cur = torch.randint(0, len(self.index_ds), ()).item()
            item = self.index_ds[index_cur]
            # clue + text: optionally prepend the wrapped emotion clue.
            text = item["text"]
            clue = "<|startofclue|>" + item["clue"] + "<|endofclue|>"
            if self.use_emotion_clue:
                text = clue + text
            text_ids = torch.tensor(self.tokenizer.encode(text), dtype=torch.int32)
            hint_once(f"raw text: {text}", "log_text")
            # speech tokens: item["token"] is a path to a .npy codec file.
            target_out = item["token"]
            codec = torch.from_numpy(np.load(target_out))
            codec_len = codec.shape[0]  # could instead use speech_length from the dataset
            aug_codec = codec.clone()
            if self.specaug is not None:  # aug_codec is the codec with random masking, for robustness
                aug_codec, _ = self.specaug(aug_codec.float().unsqueeze(0).unsqueeze(-1))
                aug_codec = aug_codec.squeeze(0).squeeze(-1).long()
            # dialogue: time/speaker ids for the multi-turn layout.
            timespk_ids = torch.from_numpy(item["timespk_ids"])
            # mixup: interleave text/timespk/codec into one LM sequence.
            type_id = item["type_id"]
            input_ids, labels, text_flag, codec_flag, timespk_flag = self.mixup_text_codec(
                text_ids, aug_codec, timespk_ids, type_id
            )
            # face: item["face"] is a path to a pickle with per-frame
            # embeddings. NOTE(review): pickle.load on dataset-supplied
            # paths is unsafe on untrusted data — confirm the data source
            # is trusted.
            face_features = item["face"]
            face_emb = torch.zeros((codec_len, self.face_size), dtype=torch.float32)  # face_emb has the same length as codec_len
            with open(face_features, 'rb') as f:
                stat_obj = pickle.load(f)
            embeddings = stat_obj['embeddings']
            faceI = stat_obj['faceI']
            # Each embedding is broadcast over up to 5 frames starting at
            # its frame index, clipped to the codec length.
            for emb, frameI in zip(embeddings, faceI):
                fi = int(frameI)
                if 0 <= fi < codec_len:
                    end = min(fi + 5, codec_len)
                    face_emb[fi:end] = torch.from_numpy(emb).expand(end - fi, -1)
            # attention_mask spans the full sequence length, i.e. input_ids =
            # (sos, <|startofclue|>, clue, <|endofclue|>, text, type_id,
            # timespk_ids, turn_of_speech, speech, eos)
            attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
            # rebind codec_len from int to a 1-element int32 tensor for batching
            codec_len = torch.tensor([codec_len], dtype=torch.int32)
            output = {
                "input_ids": input_ids,
                "face_emb": face_emb,
                "attention_mask": attention_mask,
                "labels_ids": labels,
                "text_flag": text_flag,
                "codec_flag": codec_flag,
                "timespk_flag": timespk_flag,
                "codec_len": codec_len,
            }
            break
        return output
    def collator(self, samples: list = None):
        """
        Pad a list of ``__getitem__`` outputs into batch tensors.

        ``None`` samples are skipped. Tensor fields are padded with
        ``int_pad_value`` (int32/int64 dtypes) or ``float_pad_value``
        (other dtypes). Unless ``batch_type == "example"``, the padded
        batch must satisfy ``batch * seq_len <= batch_size_token_max``;
        otherwise the last sample is dropped and padding is redone, up to
        ``retry`` times.
        """
        for idx in range(self.retry):
            badcase_flag = False
            outputs = {}
            # Gather per-key lists across all non-None samples.
            for sample in samples:
                if sample is None:
                    continue
                for key in sample.keys():
                    if key not in outputs:
                        outputs[key] = []
                    if isinstance(sample[key], (list, tuple)):
                        outputs[key].extend(sample[key])
                    else:
                        outputs[key].append(sample[key])
            # Pad each tensor field to the max length in the batch.
            for key, data_list in outputs.items():
                if isinstance(data_list[0], torch.Tensor):
                    if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32:
                        pad_value = self.int_pad_value
                    else:
                        pad_value = self.float_pad_value
                    outputs[key] = torch.nn.utils.rnn.pad_sequence(
                        data_list, batch_first=True, padding_value=pad_value
                    )
            # Token-budget check: drop the last sample and retry when the
            # padded batch exceeds batch_size_token_max.
            if self.batch_type != "example":
                b, t = outputs["input_ids"].shape
                if b > 1 and b * t > self.batch_size_token_max:
                    logging.info(
                        f"Warning, {idx}th, b*t: {b}*{t}={b * t} > batch_size_token_max: {self.batch_size_token_max}, drop last data"
                    )
                    samples = samples[:-1]
                    continue
            break
        return outputs