import unittest

from transformers import (
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
    TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
    Pipeline,
    ZeroShotClassificationPipeline,
    pipeline,
)
from transformers.testing_utils import (
    is_pipeline_test,
    is_torch_available,
    nested_simplify,
    require_tf,
    require_torch,
    slow,
)

from .test_pipelines_common import ANY


if is_torch_available():
    import torch


# These model types need image inputs in addition to text, so the generic
# text-only pipeline tests skip them.
_TO_SKIP = {"LayoutLMv2Config", "LayoutLMv3Config"}


@is_pipeline_test
class ZeroShotClassificationPipelineTests(unittest.TestCase):
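    """Test suite for the zero-shot-classification pipeline.

    Covers candidate-label parsing, hypothesis templates, batching, error
    handling, entailment-label detection, truncation of over-long inputs, and
    expected outputs for tiny and full-size models on both the PyTorch and
    TensorFlow backends.
    """
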
    model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
    tf_model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING

    if not hasattr(model_mapping, "is_dummy"):
        model_mapping = {config: model for config, model in model_mapping.items() if config.__name__ not in _TO_SKIP}
    if not hasattr(tf_model_mapping, "is_dummy"):
        tf_model_mapping = {
            config: model for config, model in tf_model_mapping.items() if config.__name__ not in _TO_SKIP
        }

    def get_test_pipeline(
        self,
        model,
        tokenizer=None,
        image_processor=None,
        feature_extractor=None,
        processor=None,
        torch_dtype="float32",
    ):
        classifier = ZeroShotClassificationPipeline(
            model=model,
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
            image_processor=image_processor,
            processor=processor,
            torch_dtype=torch_dtype,
            candidate_labels=["politics", "health"],
        )
        return classifier, ["Who are you voting for in 2020?", "My stomach hurts."]

    def run_pipeline_test(self, classifier, _):
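        """Exercise candidate-label parsing (kwarg, positional, single string,
        comma-separated string, list), custom hypothesis templates, batched
        inputs, and the errors raised on empty or malformed arguments."""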
        outputs = classifier("Who are you voting for in 2020?", candidate_labels="politics")
        self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})

        # Candidate labels can also be passed positionally, without the kwarg
        outputs = classifier("Who are you voting for in 2020?", ["politics"])
        self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})

        outputs = classifier("Who are you voting for in 2020?", candidate_labels=["politics"])
        self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})

        outputs = classifier("Who are you voting for in 2020?", candidate_labels="politics, public health")
        self.assertEqual(
            outputs, {"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]}
        )
        self.assertAlmostEqual(sum(nested_simplify(outputs["scores"])), 1.0)

        outputs = classifier("Who are you voting for in 2020?", candidate_labels=["politics", "public health"])
        self.assertEqual(
            outputs, {"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]}
        )
        self.assertAlmostEqual(sum(nested_simplify(outputs["scores"])), 1.0)

        outputs = classifier(
            "Who are you voting for in 2020?", candidate_labels="politics", hypothesis_template="This text is about {}"
        )
        self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})

        # Batched inputs return one result dict per input sequence
        outputs = classifier(["I am happy"], ["positive", "negative"])
        self.assertEqual(
            outputs,
            [
                {"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]}
                for i in range(1)
            ],
        )
        outputs = classifier(["I am happy", "I am sad"], ["positive", "negative"])
        self.assertEqual(
            outputs,
            [
                {"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]}
                for i in range(2)
            ],
        )

        with self.assertRaises(ValueError):
            classifier("", candidate_labels="politics")

        with self.assertRaises(TypeError):
            classifier(None, candidate_labels="politics")

        with self.assertRaises(ValueError):
            classifier("Who are you voting for in 2020?", candidate_labels="")

        with self.assertRaises(TypeError):
            classifier("Who are you voting for in 2020?", candidate_labels=None)

        with self.assertRaises(ValueError):
            classifier(
                "Who are you voting for in 2020?",
                candidate_labels="politics",
                hypothesis_template="Not formatting template",
            )

        with self.assertRaises(AttributeError):
            classifier(
                "Who are you voting for in 2020?",
                candidate_labels="politics",
                hypothesis_template=None,
            )

        self.run_entailment_id(classifier)

    def run_entailment_id(self, zero_shot_classifier: Pipeline):
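        """Verify that entailment_id is recovered from config.label2id: a label
        whose (case-insensitive) name marks entailment wins, with a fallback of
        -1 when no such label exists."""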
        config = zero_shot_classifier.model.config
        original_label2id = config.label2id
        original_entailment = zero_shot_classifier.entailment_id

        config.label2id = {"LABEL_0": 0, "LABEL_1": 1, "LABEL_2": 2}
        self.assertEqual(zero_shot_classifier.entailment_id, -1)

        config.label2id = {"entailment": 0, "neutral": 1, "contradiction": 2}
        self.assertEqual(zero_shot_classifier.entailment_id, 0)

        config.label2id = {"ENTAIL": 0, "NON-ENTAIL": 1}
        self.assertEqual(zero_shot_classifier.entailment_id, 0)

        config.label2id = {"ENTAIL": 2, "NEUTRAL": 1, "CONTR": 0}
        self.assertEqual(zero_shot_classifier.entailment_id, 2)

        zero_shot_classifier.model.config.label2id = original_label2id
        self.assertEqual(original_entailment, zero_shot_classifier.entailment_id)

    @require_torch
    def test_truncation(self):
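        """An input far longer than the model's maximum sequence length must be
        truncated by the pipeline rather than raise an error."""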
        zero_shot_classifier = pipeline(
            "zero-shot-classification",
            model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
            framework="pt",
        )
        zero_shot_classifier(
            "Who are you voting for in 2020?" * 100, candidate_labels=["politics", "public health", "science"]
        )

    @require_torch
    def test_small_model_pt(self):
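        """Smoke test on a tiny PyTorch model; the tiny test model is expected
        to spread scores nearly uniformly across the candidate labels."""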
        zero_shot_classifier = pipeline(
            "zero-shot-classification",
            model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
            framework="pt",
        )
        outputs = zero_shot_classifier(
            "Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
        )

        self.assertEqual(
            nested_simplify(outputs),
            {
                "sequence": "Who are you voting for in 2020?",
                "labels": ["science", "public health", "politics"],
                "scores": [0.333, 0.333, 0.333],
            },
        )

    @require_torch
    def test_small_model_pt_fp16(self):
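        """Same tiny-model check with the model loaded in float16."""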
        zero_shot_classifier = pipeline(
            "zero-shot-classification",
            model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
            framework="pt",
            torch_dtype=torch.float16,
        )
        outputs = zero_shot_classifier(
            "Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
        )

        self.assertEqual(
            nested_simplify(outputs),
            {
                "sequence": "Who are you voting for in 2020?",
                "labels": ["science", "public health", "politics"],
                "scores": [0.333, 0.333, 0.333],
            },
        )

    @require_torch
    def test_small_model_pt_bf16(self):
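        """Same tiny-model check with the model loaded in bfloat16."""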
        zero_shot_classifier = pipeline(
            "zero-shot-classification",
            model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
            framework="pt",
            torch_dtype=torch.bfloat16,
        )
        outputs = zero_shot_classifier(
            "Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
        )

        self.assertEqual(
            nested_simplify(outputs),
            {
                "sequence": "Who are you voting for in 2020?",
                "labels": ["science", "public health", "politics"],
                "scores": [0.333, 0.333, 0.333],
            },
        )

    @require_tf
    def test_small_model_tf(self):
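        """Smoke test on the tiny model with the TensorFlow backend."""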
        zero_shot_classifier = pipeline(
            "zero-shot-classification",
            model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
            framework="tf",
        )
        outputs = zero_shot_classifier(
            "Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
        )

        self.assertEqual(
            nested_simplify(outputs),
            {
                "sequence": "Who are you voting for in 2020?",
                "labels": ["science", "public health", "politics"],
                "scores": [0.333, 0.333, 0.333],
            },
        )

    @slow
    @require_torch
    def test_large_model_pt(self):
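        """Check concrete predictions from FacebookAI/roberta-large-mnli on
        PyTorch, including multi_label scoring of a long abstract."""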
        zero_shot_classifier = pipeline(
            "zero-shot-classification", model="FacebookAI/roberta-large-mnli", framework="pt"
        )
        outputs = zero_shot_classifier(
            "Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
        )

        self.assertEqual(
            nested_simplify(outputs),
            {
                "sequence": "Who are you voting for in 2020?",
                "labels": ["politics", "public health", "science"],
                "scores": [0.976, 0.015, 0.009],
            },
        )
        outputs = zero_shot_classifier(
            "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks"
            " in an encoder-decoder configuration. The best performing models also connect the encoder and decoder"
            " through an attention mechanism. We propose a new simple network architecture, the Transformer, based"
            " solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two"
            " machine translation tasks show these models to be superior in quality while being more parallelizable"
            " and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014"
            " English-to-German translation task, improving over the existing best results, including ensembles by"
            " over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new"
            " single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small"
            " fraction of the training costs of the best models from the literature. We show that the Transformer"
            " generalizes well to other tasks by applying it successfully to English constituency parsing both with"
            " large and limited training data.",
            candidate_labels=["machine learning", "statistics", "translation", "vision"],
            multi_label=True,
        )
        self.assertEqual(
            nested_simplify(outputs),
            {
                "sequence": (
                    "The dominant sequence transduction models are based on complex recurrent or convolutional neural"
                    " networks in an encoder-decoder configuration. The best performing models also connect the"
                    " encoder and decoder through an attention mechanism. We propose a new simple network"
                    " architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence"
                    " and convolutions entirely. Experiments on two machine translation tasks show these models to be"
                    " superior in quality while being more parallelizable and requiring significantly less time to"
                    " train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task,"
                    " improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014"
                    " English-to-French translation task, our model establishes a new single-model state-of-the-art"
                    " BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training"
                    " costs of the best models from the literature. We show that the Transformer generalizes well to"
                    " other tasks by applying it successfully to English constituency parsing both with large and"
                    " limited training data."
                ),
                "labels": ["translation", "machine learning", "vision", "statistics"],
                "scores": [0.817, 0.713, 0.018, 0.018],
            },
        )

    @slow
    @require_tf
    def test_large_model_tf(self):
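        """Same checks as test_large_model_pt, on the TensorFlow backend."""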
        zero_shot_classifier = pipeline(
            "zero-shot-classification", model="FacebookAI/roberta-large-mnli", framework="tf"
        )
        outputs = zero_shot_classifier(
            "Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
        )

        self.assertEqual(
            nested_simplify(outputs),
            {
                "sequence": "Who are you voting for in 2020?",
                "labels": ["politics", "public health", "science"],
                "scores": [0.976, 0.015, 0.009],
            },
        )
        outputs = zero_shot_classifier(
            "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks"
            " in an encoder-decoder configuration. The best performing models also connect the encoder and decoder"
            " through an attention mechanism. We propose a new simple network architecture, the Transformer, based"
            " solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two"
            " machine translation tasks show these models to be superior in quality while being more parallelizable"
            " and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014"
            " English-to-German translation task, improving over the existing best results, including ensembles by"
            " over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new"
            " single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small"
            " fraction of the training costs of the best models from the literature. We show that the Transformer"
            " generalizes well to other tasks by applying it successfully to English constituency parsing both with"
            " large and limited training data.",
            candidate_labels=["machine learning", "statistics", "translation", "vision"],
            multi_label=True,
        )
        self.assertEqual(
            nested_simplify(outputs),
            {
                "sequence": (
                    "The dominant sequence transduction models are based on complex recurrent or convolutional neural"
                    " networks in an encoder-decoder configuration. The best performing models also connect the"
                    " encoder and decoder through an attention mechanism. We propose a new simple network"
                    " architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence"
                    " and convolutions entirely. Experiments on two machine translation tasks show these models to be"
                    " superior in quality while being more parallelizable and requiring significantly less time to"
                    " train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task,"
                    " improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014"
                    " English-to-French translation task, our model establishes a new single-model state-of-the-art"
                    " BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training"
                    " costs of the best models from the literature. We show that the Transformer generalizes well to"
                    " other tasks by applying it successfully to English constituency parsing both with large and"
                    " limited training data."
                ),
                "labels": ["translation", "machine learning", "vision", "statistics"],
                "scores": [0.817, 0.713, 0.018, 0.018],
            },
        )