File size: 8,205 Bytes
443d045 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
"""This module should not be used directly as its API is subject to change. Instead,
please use the `gr.Interface.from_pipeline()` function."""
from __future__ import annotations
from typing import TYPE_CHECKING, Dict
from gradio import components
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from transformers import pipelines
def _label_confidences(results):
    """
    Builds a {label: score} mapping from a list of classification results.
    Parameters:
        results (list[dict]): items carrying a "label" and a "score" key
    Returns:
        (dict): maps the part of each label before the first ", " to its score
    """
    return {item["label"].split(", ")[0]: item["score"] for item in results}


def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> Dict:
    """
    Gets the appropriate Interface kwargs for a given Hugging Face transformers.Pipeline.
    Parameters:
        pipeline (transformers.Pipeline): the transformers.Pipeline from which to create an interface
    Returns:
        (dict): a dictionary of kwargs ("inputs", "outputs", "fn", "title") that can be used to construct an Interface object
    Raises:
        ImportError: if the transformers library is not installed
        ValueError: if `pipeline` is not a transformers.Pipeline, or is a pipeline type that is not handled here
    """
    try:
        import transformers
        from transformers import pipelines
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "transformers not installed. Please try `pip install transformers`"
        ) from err
    if not isinstance(pipeline, pipelines.base.Pipeline):
        raise ValueError("pipeline must be a transformers.Pipeline")

    # Handle the different pipelines. The hasattr() checks make sure the pipeline
    # class exists in the version of the transformers library that the user has
    # installed before its submodule is dereferenced.
    if hasattr(transformers, "AudioClassificationPipeline") and isinstance(
        pipeline, pipelines.audio_classification.AudioClassificationPipeline
    ):
        pipeline_info = {
            "inputs": components.Audio(
                source="microphone", type="filepath", label="Input"
            ),
            "outputs": components.Label(label="Class"),
            "preprocess": lambda i: {"inputs": i},
            "postprocess": _label_confidences,
        }
    elif hasattr(transformers, "AutomaticSpeechRecognitionPipeline") and isinstance(
        pipeline,
        pipelines.automatic_speech_recognition.AutomaticSpeechRecognitionPipeline,
    ):
        pipeline_info = {
            "inputs": components.Audio(
                source="microphone", type="filepath", label="Input"
            ),
            "outputs": components.Textbox(label="Output"),
            "preprocess": lambda i: {"inputs": i},
            "postprocess": lambda r: r["text"],
        }
    elif hasattr(transformers, "FeatureExtractionPipeline") and isinstance(
        pipeline, pipelines.feature_extraction.FeatureExtractionPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(label="Input"),
            "outputs": components.Dataframe(label="Output"),
            "preprocess": lambda x: {"inputs": x},
            "postprocess": lambda r: r[0],
        }
    elif hasattr(transformers, "FillMaskPipeline") and isinstance(
        pipeline, pipelines.fill_mask.FillMaskPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(label="Input"),
            "outputs": components.Label(label="Classification"),
            "preprocess": lambda x: {"inputs": x},
            "postprocess": lambda r: {i["token_str"]: i["score"] for i in r},
        }
    elif hasattr(transformers, "ImageClassificationPipeline") and isinstance(
        pipeline, pipelines.image_classification.ImageClassificationPipeline
    ):
        pipeline_info = {
            "inputs": components.Image(type="filepath", label="Input Image"),
            "outputs": components.Label(type="confidences", label="Classification"),
            "preprocess": lambda i: {"images": i},
            "postprocess": _label_confidences,
        }
    elif hasattr(transformers, "QuestionAnsweringPipeline") and isinstance(
        pipeline, pipelines.question_answering.QuestionAnsweringPipeline
    ):
        pipeline_info = {
            "inputs": [
                components.Textbox(lines=7, label="Context"),
                components.Textbox(label="Question"),
            ],
            "outputs": [
                components.Textbox(label="Answer"),
                components.Label(label="Score"),
            ],
            "preprocess": lambda c, q: {"context": c, "question": q},
            "postprocess": lambda r: (r["answer"], r["score"]),
        }
    # Summarization and Translation are tested before Text2TextGeneration so the
    # more specific branch wins when a pipeline matches more than one of them.
    elif hasattr(transformers, "SummarizationPipeline") and isinstance(
        pipeline, pipelines.text2text_generation.SummarizationPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(lines=7, label="Input"),
            "outputs": components.Textbox(label="Summary"),
            "preprocess": lambda x: {"inputs": x},
            "postprocess": lambda r: r[0]["summary_text"],
        }
    elif hasattr(transformers, "TextClassificationPipeline") and isinstance(
        pipeline, pipelines.text_classification.TextClassificationPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(label="Input"),
            "outputs": components.Label(label="Classification"),
            "preprocess": lambda x: [x],
            "postprocess": _label_confidences,
        }
    elif hasattr(transformers, "TextGenerationPipeline") and isinstance(
        pipeline, pipelines.text_generation.TextGenerationPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(label="Input"),
            "outputs": components.Textbox(label="Output"),
            "preprocess": lambda x: {"text_inputs": x},
            "postprocess": lambda r: r[0]["generated_text"],
        }
    elif hasattr(transformers, "TranslationPipeline") and isinstance(
        pipeline, pipelines.text2text_generation.TranslationPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(label="Input"),
            "outputs": components.Textbox(label="Translation"),
            "preprocess": lambda x: [x],
            "postprocess": lambda r: r[0]["translation_text"],
        }
    elif hasattr(transformers, "Text2TextGenerationPipeline") and isinstance(
        pipeline, pipelines.text2text_generation.Text2TextGenerationPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(label="Input"),
            "outputs": components.Textbox(label="Generated Text"),
            "preprocess": lambda x: [x],
            "postprocess": lambda r: r[0]["generated_text"],
        }
    elif hasattr(transformers, "ZeroShotClassificationPipeline") and isinstance(
        pipeline, pipelines.zero_shot_classification.ZeroShotClassificationPipeline
    ):
        pipeline_info = {
            "inputs": [
                components.Textbox(label="Input"),
                components.Textbox(label="Possible class names (comma-separated)"),
                components.Checkbox(label="Allow multiple true classes"),
            ],
            "outputs": components.Label(label="Classification"),
            "preprocess": lambda i, c, m: {
                "sequences": i,
                "candidate_labels": c,
                "multi_label": m,
            },
            "postprocess": lambda r: dict(zip(r["labels"], r["scores"])),
        }
    else:
        raise ValueError(f"Unsupported pipeline type: {type(pipeline)}")

    preprocess = pipeline_info["preprocess"]
    postprocess = pipeline_info["postprocess"]

    # define the function that will be called by the Interface
    def fn(*params):
        payload = preprocess(*params)
        # The call style follows the shape of the preprocessed payload: dicts are
        # passed as keyword arguments, lists positionally. Deciding by pipeline
        # class instead would mis-route SummarizationPipeline, which subclasses
        # Text2TextGenerationPipeline in transformers but produces a dict payload
        # here (unpacking it with * would pass the key "inputs" as the text).
        if isinstance(payload, dict):
            raw_output = pipeline(**payload)
        else:
            raw_output = pipeline(*payload)
        return postprocess(raw_output)

    # Only Interface kwargs are returned; the pre/postprocess hooks stay private
    # to fn. The title of the Interface is the class name of the wrapped model.
    return {
        "inputs": pipeline_info["inputs"],
        "outputs": pipeline_info["outputs"],
        "fn": fn,
        "title": pipeline.model.__class__.__name__,
    }
|