# -*- coding: utf-8 -*-
# @Time    : 2021/12/6 3:35 PM
# @Author  : JianingWang
# @File    : __init__.py
# from models.chid_mlm import BertForChidMLM
from models.multiple_choice.duma import BertDUMAForMultipleChoice, AlbertDUMAForMultipleChoice, MegatronDumaForMultipleChoice
from models.span_extraction.global_pointer import BertForEffiGlobalPointer, RobertaForEffiGlobalPointer, RoformerForEffiGlobalPointer, MegatronForEffiGlobalPointer
from transformers import AutoModelForTokenClassification, AutoModelForSequenceClassification, AutoModelForMaskedLM, AutoModelForMultipleChoice, \
    AutoModelForQuestionAnswering, AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers.models.roformer import RoFormerTokenizer
from transformers.models.bert import BertTokenizerFast, BertForTokenClassification, BertTokenizer
from transformers.models.roberta.tokenization_roberta import RobertaTokenizer
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from transformers.models.bart.tokenization_bart import BartTokenizer
from transformers.models.t5.tokenization_t5 import T5Tokenizer
from transformers.models.plbart.tokenization_plbart import PLBartTokenizer
# from models.deberta import DebertaV2ForMultipleChoice, DebertaForMultipleChoice
# from models.fengshen.models.longformer import LongformerForMultipleChoice
from models.kg import BertForPretrainWithKG, BertForPretrainWithKGV2
from models.language_modeling.mlm import BertForMaskedLM, RobertaForMaskedLM, AlbertForMaskedLM, RoFormerForMaskedLM
# from models.sequence_classification.classification import build_cls_model
from models.multiple_choice.multiple_choice_tag import BertForTagMultipleChoice, RoFormerForTagMultipleChoice, MegatronBertForTagMultipleChoice
from models.multiple_choice.multiple_choice import MegatronBertForMultipleChoice, MegatronBertRDropForMultipleChoice
from models.semeval7 import DebertaV2ForSemEval7MultiTask
from models.sequence_matching.fusion_siamese import BertForFusionSiamese, BertForWSC
# from roformer import RoFormerForTokenClassification, RoFormerForSequenceClassification
from models.fewshot_learning.span_proto import SpanProto
from models.fewshot_learning.token_proto import TokenProto
from models.sequence_labeling.head_token_cls import (
    BertSoftmaxForSequenceLabeling, BertCrfForSequenceLabeling,
    RobertaSoftmaxForSequenceLabeling, RobertaCrfForSequenceLabeling,
    AlbertSoftmaxForSequenceLabeling, AlbertCrfForSequenceLabeling,
    MegatronBertSoftmaxForSequenceLabeling, MegatronBertCrfForSequenceLabeling,
)
from models.span_extraction.span_for_ner import BertSpanForNer, RobertaSpanForNer, AlbertSpanForNer, MegatronBertSpanForNer
from models.language_modeling.kpplm import BertForWikiKGPLM, RoBertaKPPLMForProcessedWikiKGPLM, DeBertaKPPLMForProcessedWikiKGPLM
from models.language_modeling.causal_lm import GPT2ForCausalLM
from models.sequence_classification.head_cls import (
    BertForSequenceClassification, BertPrefixForSequenceClassification,
    BertPtuningForSequenceClassification, BertAdapterForSequenceClassification,
    RobertaForSequenceClassification, RobertaPrefixForSequenceClassification,
    RobertaPtuningForSequenceClassification, RobertaAdapterForSequenceClassification,
    BartForSequenceClassification, GPT2ForSequenceClassification,
)
from models.sequence_classification.masked_prompt_cls import (
    PromptBertForSequenceClassification, PromptBertPtuningForSequenceClassification,
    PromptBertPrefixForSequenceClassification, PromptBertAdapterForSequenceClassification,
    PromptRobertaForSequenceClassification, PromptRobertaPtuningForSequenceClassification,
    PromptRobertaPrefixForSequenceClassification, PromptRobertaAdapterForSequenceClassification,
)
from models.sequence_classification.causal_prompt_cls import PromptGPT2ForSequenceClassification
from models.code.code_classification import (
    RobertaForCodeClassification, CodeBERTForCodeClassification,
    GraphCodeBERTForCodeClassification, PLBARTForCodeClassification, CodeT5ForCodeClassification,
)
from models.code.code_generation import PLBARTForCodeGeneration
from models.reinforcement_learning.actor import CausalActor
from models.reinforcement_learning.critic import AutoModelCritic
from models.reinforcement_learning.reward_model import RobertaForReward, GPT2ForReward

# Models for pre-training.
PRETRAIN_MODEL_CLASSES = {
    "mlm": {
        "bert": BertForMaskedLM,
        "roberta": RobertaForMaskedLM,
        "albert": AlbertForMaskedLM,
        "roformer": RoFormerForMaskedLM,
    },
    "auto_mlm": AutoModelForMaskedLM,
    "causal_lm": {
        "gpt2": GPT2ForCausalLM,
        "bart": None,
        "t5": None,
        "llama": None,
    },
    "auto_causal_lm": AutoModelForCausalLM,
}
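# Usage sketch (illustrative only; the checkpoint name below is a placeholder,
# not part of this repo, and assumes the class supports HuggingFace-style
# from_pretrained):
#
#   model_class = PRETRAIN_MODEL_CLASSES["mlm"]["bert"]   # -> BertForMaskedLM
#   model = model_class.from_pretrained("bert-base-chinese")
#
# Entries mapped to None (e.g., "causal_lm" -> "bart") are reserved and not yet
# implemented; callers should check for None before instantiating.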
CLASSIFICATION_MODEL_CLASSES = {
    "auto_cls": AutoModelForSequenceClassification,  # HuggingFace classification head
    "classification": AutoModelForSequenceClassification,  # HuggingFace classification head
    "head_cls": {
        "bert": BertForSequenceClassification,
        "roberta": RobertaForSequenceClassification,
        "bart": BartForSequenceClassification,
        "gpt2": GPT2ForSequenceClassification,
    },  # standard fine-tuning head for classification, e.g., BERT + MLP
    "head_prefix_cls": {
        "bert": BertPrefixForSequenceClassification,
        "roberta": RobertaPrefixForSequenceClassification,
    },  # standard fine-tuning head with prefix-tuning, e.g., BERT + MLP
    "head_ptuning_cls": {
        "bert": BertPtuningForSequenceClassification,
        "roberta": RobertaPtuningForSequenceClassification,
    },  # standard fine-tuning head with P-tuning, e.g., BERT + MLP
    "head_adapter_cls": {
        "bert": BertAdapterForSequenceClassification,
        "roberta": RobertaAdapterForSequenceClassification,
    },  # standard fine-tuning head with adapter-tuning, e.g., BERT + MLP
    "masked_prompt_cls": {
        "bert": PromptBertForSequenceClassification,
        "roberta": PromptRobertaForSequenceClassification,
        # "deberta": PromptDebertaForSequenceClassification,
        # "deberta-v2": PromptDebertav2ForSequenceClassification,
    },  # masked LM head for prompt-based classification, e.g., BERT + MLM
    "masked_prompt_prefix_cls": {
        "bert": PromptBertPrefixForSequenceClassification,
        "roberta": PromptRobertaPrefixForSequenceClassification,
        # "deberta": PromptDebertaPrefixForSequenceClassification,
        # "deberta-v2": PromptDebertav2PrefixForSequenceClassification,
    },  # masked LM head with prefix-tuning for prompt-based classification, e.g., BERT + MLM
    "masked_prompt_ptuning_cls": {
        "bert": PromptBertPtuningForSequenceClassification,
        "roberta": PromptRobertaPtuningForSequenceClassification,
        # "deberta": PromptDebertaPtuningForSequenceClassification,
        # "deberta-v2": PromptDebertav2PtuningForSequenceClassification,
    },  # masked LM head with P-tuning for prompt-based classification, e.g., BERT + MLM
    "masked_prompt_adapter_cls": {
        "bert": PromptBertAdapterForSequenceClassification,
        "roberta": PromptRobertaAdapterForSequenceClassification,
    },  # masked LM head with adapter-tuning for prompt-based classification, e.g., BERT + MLM
    "causal_prompt_cls": {
        "gpt2": PromptGPT2ForSequenceClassification,
        "bart": None,
        "t5": None,
    },  # causal LM head for prompt-tuning, e.g., GPT-2 + LM
}

TOKEN_CLASSIFICATION_MODEL_CLASSES = {
    "auto_token_cls": AutoModelForTokenClassification,
    "head_softmax_token_cls": {
        "bert": BertSoftmaxForSequenceLabeling,
        "roberta": RobertaSoftmaxForSequenceLabeling,
        "albert": AlbertSoftmaxForSequenceLabeling,
        "megatron": MegatronBertSoftmaxForSequenceLabeling,
    },
    "head_crf_token_cls": {
        "bert": BertCrfForSequenceLabeling,
        "roberta": RobertaCrfForSequenceLabeling,
        "albert": AlbertCrfForSequenceLabeling,
        "megatron": MegatronBertCrfForSequenceLabeling,
    },
}

SPAN_EXTRACTION_MODEL_CLASSES = {
    "global_pointer": {
        "bert": BertForEffiGlobalPointer,
        "roberta": RobertaForEffiGlobalPointer,
        "roformer": RoformerForEffiGlobalPointer,
        "megatronbert": MegatronForEffiGlobalPointer,
    },
}

FEWSHOT_MODEL_CLASSES = {
    "sequence_proto": None,
    "span_proto": SpanProto,
    "token_proto": TokenProto,
}

CODE_MODEL_CLASSES = {
    "code_cls": {
        "roberta": RobertaForCodeClassification,
        "codebert": CodeBERTForCodeClassification,
        "graphcodebert": GraphCodeBERTForCodeClassification,
        "codet5": CodeT5ForCodeClassification,
        "plbart": PLBARTForCodeClassification,
    },
    "code_generation": {
        # "roberta": RobertaForCodeGeneration,
        # "codebert": BertForCodeGeneration,
        # "graphcodebert": BertForCodeGeneration,
        # "codet5": T5ForCodeGeneration,
        "plbart": PLBARTForCodeGeneration,
    },
}

REINFORCEMENT_MODEL_CLASSES = {
    "causal_actor": CausalActor,
    "auto_critic": AutoModelCritic,
    "rl_reward": {
        "roberta": RobertaForReward,
        "gpt2": GPT2ForReward,
        "gpt-neo": None,
        "opt": None,
        "llama": None,
    },
}

# task_type maps to the model class.
OTHER_MODEL_CLASSES = {
    # sequence labeling
    "bert_span_ner": BertSpanForNer,
    "roberta_span_ner": RobertaSpanForNer,
    "albert_span_ner": AlbertSpanForNer,
    "megatronbert_span_ner": MegatronBertSpanForNer,
    # sequence matching
    "fusion_siamese": BertForFusionSiamese,
    # multiple choice
    "multi_choice": AutoModelForMultipleChoice,
    "multi_choice_megatron": MegatronBertForMultipleChoice,
    "multi_choice_megatron_rdrop": MegatronBertRDropForMultipleChoice,
    "megatron_multi_choice_tag": MegatronBertForTagMultipleChoice,
    "roformer_multi_choice_tag": RoFormerForTagMultipleChoice,
    "multi_choice_tag": BertForTagMultipleChoice,
    "duma": BertDUMAForMultipleChoice,
    "duma_albert": AlbertDUMAForMultipleChoice,
    "duma_megatron": MegatronDumaForMultipleChoice,
    # language modeling
    # "bert_mlm_acc": BertForMaskedLMWithACC,
    # "roformer_mlm_acc": RoFormerForMaskedLMWithACC,
    "bert_pretrain_kg": BertForPretrainWithKG,
    "bert_pretrain_kg_v2": BertForPretrainWithKGV2,
    "kpplm_roberta": RoBertaKPPLMForProcessedWikiKGPLM,
    "kpplm_deberta": DeBertaKPPLMForProcessedWikiKGPLM,
    # other
    "clue_wsc": BertForWSC,
    "semeval7multitask": DebertaV2ForSemEval7MultiTask,
    # "debertav2_multi_choice": DebertaV2ForMultipleChoice,
    # "deberta_multi_choice": DebertaForMultipleChoice,
    # "qa": AutoModelForQuestionAnswering,
    # "roformer_cls": RoFormerForSequenceClassification,
    # "roformer_ner": RoFormerForTokenClassification,
    # "fensheng_multi_choice": LongformerForMultipleChoice,
    # "chid_mlm": BertForChidMLM,
}

# MODEL_CLASSES = dict(list(PRETRAIN_MODEL_CLASSES.items()) + list(OTHER_MODEL_CLASSES.items()))
MODEL_CLASSES_LIST = [
    PRETRAIN_MODEL_CLASSES,
    CLASSIFICATION_MODEL_CLASSES,
    TOKEN_CLASSIFICATION_MODEL_CLASSES,
    SPAN_EXTRACTION_MODEL_CLASSES,
    FEWSHOT_MODEL_CLASSES,
    CODE_MODEL_CLASSES,
    REINFORCEMENT_MODEL_CLASSES,
    OTHER_MODEL_CLASSES,
]

# Merge all task-specific registries into one flat lookup table.
# Later registries overwrite earlier ones on key collisions.
MODEL_CLASSES = dict()
for model_class in MODEL_CLASSES_LIST:
    MODEL_CLASSES.update(model_class)
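# Lookup sketch for the merged registry (illustrative only; "head_cls" and
# "bert" are keys defined above, the checkpoint name is a placeholder):
#
#   entry = MODEL_CLASSES["head_cls"]   # nested dict keyed by model_type
#   model = entry["bert"].from_pretrained("bert-base-uncased")
#
# Flat entries map a task_type directly to a class, e.g.:
#
#   MODEL_CLASSES["multi_choice"]       # -> AutoModelForMultipleChoice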
# model_type maps to the tokenizer class.
TOKENIZER_CLASSES = {
    # for natural language processing
    "auto": AutoTokenizer,
    "bert": BertTokenizerFast,
    "roberta": RobertaTokenizer,
    "wobert": RoFormerTokenizer,
    "roformer": RoFormerTokenizer,
    "bigbird": BertTokenizerFast,
    "erlangshen": BertTokenizerFast,
    "deberta": BertTokenizer,
    "roformer_v2": BertTokenizerFast,
    "gpt2": GPT2Tokenizer,
    "megatronbert": BertTokenizerFast,
    "bart": BartTokenizer,
    "t5": T5Tokenizer,
    # for programming language processing
    "codebert": RobertaTokenizer,
    "graphcodebert": RobertaTokenizer,
    "codet5": RobertaTokenizer,
    "plbart": PLBartTokenizer,
}
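# Putting the two tables together (a minimal sketch, assuming the selected
# class supports HuggingFace-style from_pretrained; the checkpoint name is a
# placeholder, not part of this repo):
#
#   model_type, task_type = "bert", "head_cls"
#   tokenizer = TOKENIZER_CLASSES[model_type].from_pretrained("bert-base-uncased")
#   model = MODEL_CLASSES[task_type][model_type].from_pretrained("bert-base-uncased")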