geored's picture
Upload folder using huggingface_hub
add8f0b verified
raw
history blame
3.26 kB
"""This module should not be used directly as its API is subject to change. Instead,
please use the `gr.Interface.from_pipeline()` function."""
from __future__ import annotations
from typing import TYPE_CHECKING
from gradio.pipelines_utils import (
handle_diffusers_pipeline,
handle_transformers_pipeline,
)
if TYPE_CHECKING:
import diffusers
import transformers
def load_from_pipeline(
    pipeline: transformers.Pipeline | diffusers.DiffusionPipeline,  # type: ignore
) -> dict:
    """
    Gets the appropriate Interface kwargs for a given Hugging Face
    transformers.Pipeline or diffusers.DiffusionPipeline.

    Parameters:
        pipeline (transformers.Pipeline | diffusers.DiffusionPipeline): the
            pipeline from which to create an interface
    Returns:
        (dict): a dictionary of kwargs that can be used to construct an
            Interface object. The handler's ``preprocess``/``postprocess``
            entries are removed and folded into an ``fn`` callable, and a
            ``title`` is added (the model class name for transformers
            pipelines, the pipeline class name for diffusers pipelines).
    Raises:
        ValueError: if ``pipeline`` is neither a transformers nor a diffusers
            pipeline, or if no interface info could be derived for it.
    """
    module_name = str(type(pipeline).__module__)
    # Decide the pipeline family once; the dispatch, the runtime call path and
    # the title all reuse this flag so they can never disagree.
    is_transformers = module_name.startswith("transformers.pipelines.")
    if is_transformers:
        pipeline_info = handle_transformers_pipeline(pipeline)
    elif module_name.startswith("diffusers.pipelines."):
        pipeline_info = handle_diffusers_pipeline(pipeline)
    else:
        raise ValueError(
            "pipeline must be a transformers.pipeline or diffusers.pipeline"
        )

    # Fail early with a clear message instead of a KeyError when removing
    # `preprocess`/`postprocess` from an empty dict below.
    if not pipeline_info:
        raise ValueError("pipeline_info can not be None.")

    interface_info = pipeline_info.copy()
    preprocess = interface_info.pop("preprocess")
    postprocess = interface_info.pop("postprocess")

    def fn(*params):
        data = preprocess(*params)
        if not is_transformers:
            # diffusers pipelines are called with the preprocessed kwargs.
            return postprocess(pipeline(**data))  # type: ignore

        from transformers import pipelines

        # special cases that are called with positional arguments
        if isinstance(
            pipeline,
            (
                pipelines.text_classification.TextClassificationPipeline,
                pipelines.text2text_generation.Text2TextGenerationPipeline,
                pipelines.text2text_generation.TranslationPipeline,
            ),
        ):
            data = pipeline(*data)
        else:
            data = pipeline(**data)  # type: ignore
        # special case for object-detection: the original input image is also
        # sent to the postprocess function
        if isinstance(
            pipeline,
            pipelines.object_detection.ObjectDetectionPipeline,
        ):
            return postprocess(data, params[0])
        return postprocess(data)

    interface_info["fn"] = fn
    # define the title/description of the Interface
    interface_info["title"] = (
        pipeline.model.__class__.__name__
        if is_transformers
        else pipeline.__class__.__name__
    )
    return interface_info