# gradio/pipelines.py
"""This module should not be used directly as its API is subject to change. Instead,
please use the `gr.Interface.from_pipeline()` function."""
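
# A minimal usage sketch of the intended public entry point,
# `gr.Interface.from_pipeline()` (the task name below is illustrative):
#
#     import gradio as gr
#     from transformers import pipeline
#
#     pipe = pipeline("text-classification")
#     gr.Interface.from_pipeline(pipe).launch()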
from __future__ import annotations

from typing import TYPE_CHECKING

from gradio.pipelines_utils import (
    handle_diffusers_pipeline,
    handle_transformers_js_pipeline,
    handle_transformers_pipeline,
)

if TYPE_CHECKING:
    import diffusers
    import transformers


def load_from_pipeline(
    pipeline: transformers.Pipeline | diffusers.DiffusionPipeline,  # type: ignore
) -> dict:
    """
    Gets the appropriate Interface kwargs for a given Hugging Face transformers.Pipeline
    or diffusers.DiffusionPipeline.
    Parameters:
        pipeline (transformers.Pipeline | diffusers.DiffusionPipeline): the pipeline from which to create an interface
    Returns:
        (dict): a dictionary of kwargs that can be used to construct an Interface object
    """
    if str(type(pipeline).__module__).startswith("transformers.pipelines."):
        pipeline_info = handle_transformers_pipeline(pipeline)
    elif str(type(pipeline).__module__).startswith("diffusers.pipelines."):
        pipeline_info = handle_diffusers_pipeline(pipeline)
    else:
        raise ValueError(
            "pipeline must be a transformers.Pipeline or diffusers.DiffusionPipeline"
        )

    def fn(*params):
        if pipeline_info:
            data = pipeline_info["preprocess"](*params)
            if str(type(pipeline).__module__).startswith("transformers.pipelines"):
                from transformers import pipelines

                # special cases that need to be handled differently
                if isinstance(
                    pipeline,
                    (
                        pipelines.text_classification.TextClassificationPipeline,
                        pipelines.text2text_generation.Text2TextGenerationPipeline,
                        pipelines.text2text_generation.TranslationPipeline,
                    ),
                ):
                    data = pipeline(*data)
                else:
                    data = pipeline(**data)  # type: ignore
                # special case for object-detection:
                # the original input image is also passed to the postprocess function
                if isinstance(
                    pipeline,
                    pipelines.object_detection.ObjectDetectionPipeline,
                ):
                    output = pipeline_info["postprocess"](data, params[0])
                else:
                    output = pipeline_info["postprocess"](data)
                return output
            elif str(type(pipeline).__module__).startswith("diffusers.pipelines"):
                data = pipeline(**data)  # type: ignore
                output = pipeline_info["postprocess"](data)
                return output
        else:
            raise ValueError("pipeline_info cannot be None.")

    interface_info = pipeline_info.copy() if pipeline_info else {}
    interface_info["fn"] = fn
    del interface_info["preprocess"]
    del interface_info["postprocess"]

    # define the title of the Interface
    interface_info["title"] = (
        pipeline.model.config.name_or_path
        if str(type(pipeline).__module__).startswith("transformers.pipelines")
        else pipeline.__class__.__name__
    )

    return interface_info
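
# A minimal sketch of how the returned kwargs can be consumed directly
# (illustrative only; `gr.Interface.from_pipeline()` performs this wrapping for you):
#
#     import gradio as gr
#     from transformers import pipeline
#     from gradio.pipelines import load_from_pipeline
#
#     kwargs = load_from_pipeline(pipeline("text-classification"))
#     gr.Interface(**kwargs).launch()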


def load_from_js_pipeline(pipeline) -> dict:
    """Gets the appropriate Interface kwargs for a given transformers_js_py pipeline."""
    if str(type(pipeline).__module__).startswith("transformers_js_py."):
        pipeline_info = handle_transformers_js_pipeline(pipeline)
    else:
        raise ValueError("pipeline must be a transformers_js_py pipeline")

    async def fn(*params):
        preprocess = pipeline_info["preprocess"]
        postprocess = pipeline_info["postprocess"]
        postprocess_takes_inputs = pipeline_info.get("postprocess_takes_inputs", False)

        preprocessed_params = preprocess(*params) if preprocess else params
        pipeline_output = await pipeline(*preprocessed_params)
        postprocessed_output = (
            postprocess(pipeline_output, *(params if postprocess_takes_inputs else ()))
            if postprocess
            else pipeline_output
        )
        return postprocessed_output

    interface_info = {
        "fn": fn,
        "inputs": pipeline_info["inputs"],
        "outputs": pipeline_info["outputs"],
        "title": f"{pipeline.task} ({pipeline.model.config._name_or_path})",
    }
    return interface_info
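
# A minimal sketch for the transformers_js_py path (assumes a Gradio-Lite /
# Pyodide environment where `transformers_js_py` is available; the task name
# and the top-level `await` are illustrative):
#
#     import gradio as gr
#     from transformers_js_py import pipeline
#     from gradio.pipelines import load_from_js_pipeline
#
#     pipe = await pipeline("sentiment-analysis")
#     gr.Interface(**load_from_js_pipeline(pipe)).launch()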