import os
import json
import logging

import torchaudio
from torch.utils.data import Dataset


def _handle_wav(wav_path, target_rate=16000):
    """Load one wav file and return it as mono audio.

    Resamples to ``target_rate`` if the file's sample rate differs, then
    keeps only the first channel.

    Returns:
        A 1-D ``torch.Tensor`` containing the audio samples.
    """
    waveform, sample_rate = torchaudio.load(wav_path)
    if sample_rate != target_rate:
        waveform = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=target_rate
        )(waveform)
    return waveform[0]
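
# Illustrative usage (the path is hypothetical):
#   mono = _handle_wav("dialogue.wav", target_rate=16000)
#   mono.shape  # torch.Size([num_samples]); single channel at 16 kHz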


def _handle_qa(obj, is_think=True, think_max_len=50):
    """Wrap one metadata dict into a chat-style prompt/solution sample.

    The two prompt variants differ only in whether a ``<think>`` analysis
    block is required before the ``<score>`` rating, so the shared text is
    built once below.
    """
    if is_think:
        requirement = "Evaluation must include `<think>` analysis and `<score>` rating."
        answer_format = (
            "<think>\n"
            "Analyse text relevance, the **appropriateness** of the paralinguistic "
            f"information, and the reasons for the score... (fewer than {think_max_len} words)\n"
            "</think>\n\n"
            "<score>X</score> (**X is 1, 3, or 5**)\n\n"
        )
    else:
        requirement = "Evaluation must include a `<score>` rating."
        answer_format = "<score>X</score> (**X is 1, 3, or 5**)\n\n"

    prompt_template = (
        "# Dialogue Response Evaluation\n\n"
        f"**IMPORTANT:** {requirement}\n\n"
        "Listen to the dialogue recording (two sentences, with a 1-second pause in "
        "between). Evaluate the quality of the **second sentence** as a response to "
        "the first, focusing on **text relevance** and the **appropriateness** of the "
        "**paralinguistic information** (e.g. emotion/age/pitch/speed/volume).\n"
        "**Note:** Focus on evaluating the appropriateness of the second sentence "
        "relative to the first, even if the first sentence itself contains "
        "contradictory information.\n\n"
        "## Scoring Criteria\n\n"
        "**1 point**: The text content is irrelevant, incorrect, or illogical. "
        "(low intelligence)\n"
        "**3 points**: The text is relevant, but the paralinguistic information is "
        "**inappropriate** for the context. (low emotional intelligence)\n"
        "**5 points**: The text is relevant, and the paralinguistic information is "
        "**appropriate** for the context, resulting in effective communication. "
        "(high intelligence and emotional intelligence)\n\n"
        "## Evaluation Requirements\n\n"
        "The response **MUST** follow this format:\n\n"
        f"{answer_format}"
    )

    return {
        "id": obj["id"],
        "prompt": [{"role": "user", "content": [
            {"type": "audio", "audio": obj["merge_wav"]},
            {"type": "text", "text": prompt_template},
        ]}],
        "solution": obj["gt_score"],
        "audio": obj.get("audio"),
        "clean_dialogue": obj.get("clean_dialogue"),
    }
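
# Illustrative call (field values are hypothetical):
#   sample = _handle_qa(
#       {"id": "dlg_0001", "merge_wav": "/wavs/dlg_0001.wav", "gt_score": 5},
#       is_think=True,
#   )
#   sample["solution"]  # -> 5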


class AudioDataset(Dataset):
    """Dataset of merged dialogue recordings with ground-truth quality scores."""

    def __init__(self, data_dir, sample_rate=16000, is_think=True,
                 think_max_len=50, load_audio=False):
        super().__init__()
        self.sample_rate = sample_rate
        self.data_dir = data_dir
        self.is_think = is_think
        self.think_max_len = think_max_len
        self.load_audio = load_audio
        self.metadata = []
        self._load_metadata()
        logging.info(f"Loaded metadata for {len(self.metadata)} dialogues from {data_dir}")

    def _load_metadata(self):
        """Collect per-dialogue metadata from every .json file in data_dir."""
        # Sort for a deterministic ordering across filesystems.
        for fname in sorted(os.listdir(self.data_dir)):
            if not fname.endswith('.json'):
                continue
            fpath = os.path.join(self.data_dir, fname)
            with open(fpath, 'r', encoding='utf-8') as f:
                try:
                    json_obj = json.load(f)
                except Exception as e:
                    logging.warning(f"Failed to load {fpath}: {e}")
                    continue
            for dialogue_id, obj in json_obj.items():
                self.metadata.append({
                    "id": dialogue_id,
                    "merge_wav": obj.get("merge_wav"),
                    "gt_score": obj.get("gt_score"),
                    "clean_dialogue": obj.get("clean_dialogue"),
                    "json_path": fpath,
                })
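
    # Expected layout of each metadata file, inferred from the parsing above
    # (paths and values are illustrative):
    # {
    #     "dlg_0001": {
    #         "merge_wav": "/wavs/dlg_0001_merged.wav",
    #         "gt_score": 5,
    #         "clean_dialogue": "..."
    #     }
    # }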

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, index):
        metadata = self.metadata[index]

        item = {
            "id": metadata["id"],
            "merge_wav": metadata["merge_wav"],
            "gt_score": metadata["gt_score"],
            "clean_dialogue": metadata["clean_dialogue"],
        }

        # Optionally decode the merged waveform; _handle_qa picks it up
        # via obj.get("audio").
        if self.load_audio and metadata["merge_wav"] and os.path.exists(metadata["merge_wav"]):
            item["audio"] = _handle_wav(metadata["merge_wav"], self.sample_rate).numpy()

        return _handle_qa(
            item,
            is_think=self.is_think,
            think_max_len=self.think_max_len,
        )
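

# A minimal smoke test, assuming a directory "./data" of metadata .json files
# (the path is hypothetical):
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    dataset = AudioDataset("./data", is_think=True, load_audio=False)
    if len(dataset) > 0:
        sample = dataset[0]
        print(sample["id"], sample["solution"])
        # First 80 characters of the rendered prompt text.
        print(sample["prompt"][0]["content"][1]["text"][:80])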