EdgeTA / data /datasets /machine_translation /universal_asc_19_domains.py
LINC-BIT's picture
Upload 1912 files
b84549f verified
raw
history blame
12.9 kB
import os
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from typing import Dict, List, Optional, Any
from utils.common.data_record import read_json
from ..sentiment_classification.global_bert_tokenizer import get_tokenizer
# Custom dataset class
class UniversalASC19DomainsTranslationDataset(Dataset):
    """Translation dataset of (source, target) sentence pairs read from a JSON file.

    Records are loaded from ``<root_dir>/<split>.json.translate_data`` (the
    ``'val'`` split is read from the ``dev`` file). Every record must provide a
    ``'src'`` and a ``'dst'`` entry. Tokenization happens lazily in
    ``__getitem__`` using the tokenizer returned by ``get_tokenizer()``.
    """

    # Value written over padding positions of the target ids.
    # NOTE(review): presumably chosen to match the default ``ignore_index`` of
    # ``torch.nn.CrossEntropyLoss`` -- confirm against the training loop.
    _IGNORE_LABEL = -100

    def __init__(self, root_dir: str, split: str, transform: Any,
                 classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
        """Load all sentence pairs of one split into memory.

        Args:
            root_dir: Directory containing the ``*.json.translate_data`` files.
            split: Split name; ``'val'`` is mapped to the ``dev`` file, any
                other value is used verbatim as the file stem.
            transform: Must be ``None``; transforms are not supported.
            classes: Not used by this dataset.
            ignore_classes: Not used by this dataset.
            idx_map: Not used by this dataset.

        Raises:
            AssertionError: If ``transform`` is not ``None``.
        """
        assert transform is None, 'transform is not supported by this dataset'

        self.tokenizer = get_tokenizer()
        # NOTE(review): with max_length=None the tokenizer is expected to fall
        # back to its own maximum length for padding/truncation -- confirm
        # that this is the intended sequence length.
        self.max_length = None

        file_stem = 'dev' if split == 'val' else split
        json_file_path = os.path.join(root_dir, f'{file_stem}.json.translate_data')
        anns = read_json(json_file_path)

        self.srcs = []
        self.tgts = []
        # Single pass so both lists stay aligned even if `anns` is a one-shot iterable.
        for info in anns:
            self.srcs.append(info['src'])
            self.tgts.append(info['dst'])

    def __len__(self):
        """Return the number of sentence pairs."""
        return len(self.srcs)

    def _encode(self, text):
        """Tokenize ``text`` with the dataset-wide padding/truncation settings."""
        return self.tokenizer(
            text, max_length=self.max_length, padding="max_length",
            truncation=True, return_tensors="pt"
        )

    def __getitem__(self, idx):
        """Return ``(x, y)`` for the pair at ``idx``.

        ``x`` maps every tokenizer output key of the source sentence to its
        tensor with the leading batch dimension removed. ``y`` holds the target
        ``input_ids`` as int64, with every position equal to the tokenizer's
        ``pad_token_id`` replaced by ``-100``.
        """
        encoded_src = self._encode(self.srcs[idx])
        encoded_tgt = self._encode(self.tgts[idx])

        x = {key: tensor.squeeze(0) for key, tensor in encoded_src.items()}

        target_ids = encoded_tgt['input_ids'][0].to(torch.long)
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            # No pad token defined: nothing to mask, hand back an independent copy.
            return x, target_ids.clone()

        # Vectorized replacement of padding ids; avoids a Python-level loop
        # over every token of every sample.
        y = target_ids.masked_fill(target_ids == pad_id, self._IGNORE_LABEL)
        return x, y
from ..ab_dataset import ABDataset
from ..registery import dataset_register
@dataset_register(
    name='HL5Domains-ApexAD2600Progressive-Tr',
    classes=['unknown'],
    task_type='Machine Translation',
    object_type='Generic',
    class_aliases=[],
    shift_type=None,
)
class HL5Domains_ApexAD2600Progressive(ABDataset):
    """Entry registered as 'HL5Domains-ApexAD2600Progressive-Tr' (task type 'Machine Translation')."""

    def create_dataset(self, root_dir: str, split: str, transform,
                       classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
        """Build a ``UniversalASC19DomainsTranslationDataset`` from the given arguments."""
        dataset = UniversalASC19DomainsTranslationDataset(
            root_dir=root_dir,
            split=split,
            transform=transform,
            classes=classes,
            ignore_classes=ignore_classes,
            idx_map=idx_map,
        )
        return dataset
@dataset_register(
    name='HL5Domains-CanonG3-Tr',
    classes=['unknown'],
    task_type='Machine Translation',
    object_type='Generic',
    class_aliases=[],
    shift_type=None,
)
class HL5Domains_CanonG3(ABDataset):
    """Entry registered as 'HL5Domains-CanonG3-Tr' (task type 'Machine Translation')."""

    def create_dataset(self, root_dir: str, split: str, transform,
                       classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
        """Build a ``UniversalASC19DomainsTranslationDataset`` from the given arguments."""
        dataset = UniversalASC19DomainsTranslationDataset(
            root_dir=root_dir,
            split=split,
            transform=transform,
            classes=classes,
            ignore_classes=ignore_classes,
            idx_map=idx_map,
        )
        return dataset
@dataset_register(
    name='HL5Domains-CreativeLabsNomadJukeboxZenXtra40GB-Tr',
    classes=['unknown'],
    task_type='Machine Translation',
    object_type='Generic',
    class_aliases=[],
    shift_type=None,
)
class HL5Domains_CreativeLabsNomadJukeboxZenXtra40GB(ABDataset):
    """Entry registered as 'HL5Domains-CreativeLabsNomadJukeboxZenXtra40GB-Tr' (task type 'Machine Translation')."""

    def create_dataset(self, root_dir: str, split: str, transform,
                       classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
        """Build a ``UniversalASC19DomainsTranslationDataset`` from the given arguments."""
        dataset = UniversalASC19DomainsTranslationDataset(
            root_dir=root_dir,
            split=split,
            transform=transform,
            classes=classes,
            ignore_classes=ignore_classes,
            idx_map=idx_map,
        )
        return dataset
@dataset_register(
    name='HL5Domains-NikonCoolpix4300-Tr',
    classes=['unknown'],
    task_type='Machine Translation',
    object_type='Generic',
    class_aliases=[],
    shift_type=None,
)
class HL5Domains_NikonCoolpix4300(ABDataset):
    """Entry registered as 'HL5Domains-NikonCoolpix4300-Tr' (task type 'Machine Translation')."""

    def create_dataset(self, root_dir: str, split: str, transform,
                       classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
        """Build a ``UniversalASC19DomainsTranslationDataset`` from the given arguments."""
        dataset = UniversalASC19DomainsTranslationDataset(
            root_dir=root_dir,
            split=split,
            transform=transform,
            classes=classes,
            ignore_classes=ignore_classes,
            idx_map=idx_map,
        )
        return dataset
@dataset_register(
    name='HL5Domains-Nokia6610-Tr',
    classes=['unknown'],
    task_type='Machine Translation',
    object_type='Generic',
    class_aliases=[],
    shift_type=None,
)
class HL5Domains_Nokia6610(ABDataset):
    """Entry registered as 'HL5Domains-Nokia6610-Tr' (task type 'Machine Translation')."""

    def create_dataset(self, root_dir: str, split: str, transform,
                       classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
        """Build a ``UniversalASC19DomainsTranslationDataset`` from the given arguments."""
        dataset = UniversalASC19DomainsTranslationDataset(
            root_dir=root_dir,
            split=split,
            transform=transform,
            classes=classes,
            ignore_classes=ignore_classes,
            idx_map=idx_map,
        )
        return dataset
@dataset_register(
    name='Liu3Domains-Computer-Tr',
    classes=['unknown'],
    task_type='Machine Translation',
    object_type='Generic',
    class_aliases=[],
    shift_type=None,
)
class Liu3Domains_Computer(ABDataset):
    """Entry registered as 'Liu3Domains-Computer-Tr' (task type 'Machine Translation')."""

    def create_dataset(self, root_dir: str, split: str, transform,
                       classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
        """Build a ``UniversalASC19DomainsTranslationDataset`` from the given arguments."""
        dataset = UniversalASC19DomainsTranslationDataset(
            root_dir=root_dir,
            split=split,
            transform=transform,
            classes=classes,
            ignore_classes=ignore_classes,
            idx_map=idx_map,
        )
        return dataset
@dataset_register(
    name='Liu3Domains-Router-Tr',
    classes=['unknown'],
    task_type='Machine Translation',
    object_type='Generic',
    class_aliases=[],
    shift_type=None,
)
class Liu3Domains_Router(ABDataset):
    """Entry registered as 'Liu3Domains-Router-Tr' (task type 'Machine Translation')."""

    def create_dataset(self, root_dir: str, split: str, transform,
                       classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
        """Build a ``UniversalASC19DomainsTranslationDataset`` from the given arguments."""
        dataset = UniversalASC19DomainsTranslationDataset(
            root_dir=root_dir,
            split=split,
            transform=transform,
            classes=classes,
            ignore_classes=ignore_classes,
            idx_map=idx_map,
        )
        return dataset
@dataset_register(
    name='Liu3Domains-Speaker-Tr',
    classes=['unknown'],
    task_type='Machine Translation',
    object_type='Generic',
    class_aliases=[],
    shift_type=None,
)
class Liu3Domains_Speaker(ABDataset):
    """Entry registered as 'Liu3Domains-Speaker-Tr' (task type 'Machine Translation')."""

    def create_dataset(self, root_dir: str, split: str, transform,
                       classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
        """Build a ``UniversalASC19DomainsTranslationDataset`` from the given arguments."""
        dataset = UniversalASC19DomainsTranslationDataset(
            root_dir=root_dir,
            split=split,
            transform=transform,
            classes=classes,
            ignore_classes=ignore_classes,
            idx_map=idx_map,
        )
        return dataset
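# Scratch generator kept for reference: it prints registration boilerplate for
# every domain directory. Note that the classes registered below carry an extra
# '-Tr' suffix in `name`, which this template does not emit.
# import os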
# for domain in os.listdir('/data/zql/datasets/nlp_asc_19_domains/dat/absa/Bing9Domains/asc'):
# print(f"""
# @dataset_register(
# name='Ding9Domains-{domain}',
# classes=['unknown'],
# task_type='Machine Translation',
# object_type='Generic',
# class_aliases=[],
# shift_type=None
# )
# class Ding9Domains_{domain}(ABDataset):
# def create_dataset(self, root_dir: str, split: str, transform,
# classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
# return UniversalASC19DomainsTranslationDataset(root_dir, split, transform, classes, ignore_classes, idx_map)
# """)
@dataset_register(
    name='Ding9Domains-DiaperChamp-Tr',
    classes=['unknown'],
    task_type='Machine Translation',
    object_type='Generic',
    class_aliases=[],
    shift_type=None,
)
class Ding9Domains_DiaperChamp(ABDataset):
    """Entry registered as 'Ding9Domains-DiaperChamp-Tr' (task type 'Machine Translation')."""

    def create_dataset(self, root_dir: str, split: str, transform,
                       classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
        """Build a ``UniversalASC19DomainsTranslationDataset`` from the given arguments."""
        dataset = UniversalASC19DomainsTranslationDataset(
            root_dir=root_dir,
            split=split,
            transform=transform,
            classes=classes,
            ignore_classes=ignore_classes,
            idx_map=idx_map,
        )
        return dataset
@dataset_register(
    name='Ding9Domains-Norton-Tr',
    classes=['unknown'],
    task_type='Machine Translation',
    object_type='Generic',
    class_aliases=[],
    shift_type=None,
)
class Ding9Domains_Norton(ABDataset):
    """Entry registered as 'Ding9Domains-Norton-Tr' (task type 'Machine Translation')."""

    def create_dataset(self, root_dir: str, split: str, transform,
                       classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
        """Build a ``UniversalASC19DomainsTranslationDataset`` from the given arguments."""
        dataset = UniversalASC19DomainsTranslationDataset(
            root_dir=root_dir,
            split=split,
            transform=transform,
            classes=classes,
            ignore_classes=ignore_classes,
            idx_map=idx_map,
        )
        return dataset
@dataset_register(
    name='Ding9Domains-LinksysRouter-Tr',
    classes=['unknown'],
    task_type='Machine Translation',
    object_type='Generic',
    class_aliases=[],
    shift_type=None,
)
class Ding9Domains_LinksysRouter(ABDataset):
    """Entry registered as 'Ding9Domains-LinksysRouter-Tr' (task type 'Machine Translation')."""

    def create_dataset(self, root_dir: str, split: str, transform,
                       classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
        """Build a ``UniversalASC19DomainsTranslationDataset`` from the given arguments."""
        dataset = UniversalASC19DomainsTranslationDataset(
            root_dir=root_dir,
            split=split,
            transform=transform,
            classes=classes,
            ignore_classes=ignore_classes,
            idx_map=idx_map,
        )
        return dataset
@dataset_register(
    name='Ding9Domains-MicroMP3-Tr',
    classes=['unknown'],
    task_type='Machine Translation',
    object_type='Generic',
    class_aliases=[],
    shift_type=None,
)
class Ding9Domains_MicroMP3(ABDataset):
    """Entry registered as 'Ding9Domains-MicroMP3-Tr' (task type 'Machine Translation')."""

    def create_dataset(self, root_dir: str, split: str, transform,
                       classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
        """Build a ``UniversalASC19DomainsTranslationDataset`` from the given arguments."""
        dataset = UniversalASC19DomainsTranslationDataset(
            root_dir=root_dir,
            split=split,
            transform=transform,
            classes=classes,
            ignore_classes=ignore_classes,
            idx_map=idx_map,
        )
        return dataset
@dataset_register(
    name='Ding9Domains-Nokia6600-Tr',
    classes=['unknown'],
    task_type='Machine Translation',
    object_type='Generic',
    class_aliases=[],
    shift_type=None,
)
class Ding9Domains_Nokia6600(ABDataset):
    """Entry registered as 'Ding9Domains-Nokia6600-Tr' (task type 'Machine Translation')."""

    def create_dataset(self, root_dir: str, split: str, transform,
                       classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
        """Build a ``UniversalASC19DomainsTranslationDataset`` from the given arguments."""
        dataset = UniversalASC19DomainsTranslationDataset(
            root_dir=root_dir,
            split=split,
            transform=transform,
            classes=classes,
            ignore_classes=ignore_classes,
            idx_map=idx_map,
        )
        return dataset
@dataset_register(
    name='Ding9Domains-CanonPowerShotSD500-Tr',
    classes=['unknown'],
    task_type='Machine Translation',
    object_type='Generic',
    class_aliases=[],
    shift_type=None,
)
class Ding9Domains_CanonPowerShotSD500(ABDataset):
    """Entry registered as 'Ding9Domains-CanonPowerShotSD500-Tr' (task type 'Machine Translation')."""

    def create_dataset(self, root_dir: str, split: str, transform,
                       classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
        """Build a ``UniversalASC19DomainsTranslationDataset`` from the given arguments."""
        dataset = UniversalASC19DomainsTranslationDataset(
            root_dir=root_dir,
            split=split,
            transform=transform,
            classes=classes,
            ignore_classes=ignore_classes,
            idx_map=idx_map,
        )
        return dataset
@dataset_register(
    name='Ding9Domains-ipod-Tr',
    classes=['unknown'],
    task_type='Machine Translation',
    object_type='Generic',
    class_aliases=[],
    shift_type=None,
)
class Ding9Domains_ipod(ABDataset):
    """Entry registered as 'Ding9Domains-ipod-Tr' (task type 'Machine Translation')."""

    def create_dataset(self, root_dir: str, split: str, transform,
                       classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
        """Build a ``UniversalASC19DomainsTranslationDataset`` from the given arguments."""
        dataset = UniversalASC19DomainsTranslationDataset(
            root_dir=root_dir,
            split=split,
            transform=transform,
            classes=classes,
            ignore_classes=ignore_classes,
            idx_map=idx_map,
        )
        return dataset
@dataset_register(
    name='Ding9Domains-HitachiRouter-Tr',
    classes=['unknown'],
    task_type='Machine Translation',
    object_type='Generic',
    class_aliases=[],
    shift_type=None,
)
class Ding9Domains_HitachiRouter(ABDataset):
    """Entry registered as 'Ding9Domains-HitachiRouter-Tr' (task type 'Machine Translation')."""

    def create_dataset(self, root_dir: str, split: str, transform,
                       classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
        """Build a ``UniversalASC19DomainsTranslationDataset`` from the given arguments."""
        dataset = UniversalASC19DomainsTranslationDataset(
            root_dir=root_dir,
            split=split,
            transform=transform,
            classes=classes,
            ignore_classes=ignore_classes,
            idx_map=idx_map,
        )
        return dataset
@dataset_register(
    name='Ding9Domains-CanonS100-Tr',
    classes=['unknown'],
    task_type='Machine Translation',
    object_type='Generic',
    class_aliases=[],
    shift_type=None,
)
class Ding9Domains_CanonS100(ABDataset):
    """Entry registered as 'Ding9Domains-CanonS100-Tr' (task type 'Machine Translation')."""

    def create_dataset(self, root_dir: str, split: str, transform,
                       classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
        """Build a ``UniversalASC19DomainsTranslationDataset`` from the given arguments."""
        dataset = UniversalASC19DomainsTranslationDataset(
            root_dir=root_dir,
            split=split,
            transform=transform,
            classes=classes,
            ignore_classes=ignore_classes,
            idx_map=idx_map,
        )
        return dataset
@dataset_register(
    name='SemEval-Laptop-Tr',
    classes=['unknown'],
    task_type='Machine Translation',
    object_type='Generic',
    class_aliases=[],
    shift_type=None,
)
class SemEval_Laptop(ABDataset):
    """Entry registered as 'SemEval-Laptop-Tr' (task type 'Machine Translation')."""

    def create_dataset(self, root_dir: str, split: str, transform,
                       classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
        """Build a ``UniversalASC19DomainsTranslationDataset`` from the given arguments."""
        dataset = UniversalASC19DomainsTranslationDataset(
            root_dir=root_dir,
            split=split,
            transform=transform,
            classes=classes,
            ignore_classes=ignore_classes,
            idx_map=idx_map,
        )
        return dataset
@dataset_register(
    name='SemEval-Rest-Tr',
    classes=['unknown'],
    task_type='Machine Translation',
    object_type='Generic',
    class_aliases=[],
    shift_type=None,
)
class SemEval_Rest(ABDataset):
    """Entry registered as 'SemEval-Rest-Tr' (task type 'Machine Translation')."""

    def create_dataset(self, root_dir: str, split: str, transform,
                       classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
        """Build a ``UniversalASC19DomainsTranslationDataset`` from the given arguments."""
        dataset = UniversalASC19DomainsTranslationDataset(
            root_dir=root_dir,
            split=split,
            transform=transform,
            classes=classes,
            ignore_classes=ignore_classes,
            idx_map=idx_map,
        )
        return dataset