"""This module should not be used directly as its API is subject to change. Instead, please use the `gr.Interface.from_pipeline()` function."""

from __future__ import annotations

from typing import TYPE_CHECKING

from gradio.pipelines_utils import (
    handle_diffusers_pipeline,
    handle_transformers_js_pipeline,
    handle_transformers_pipeline,
)

if TYPE_CHECKING:
    import diffusers
    import transformers


def load_from_pipeline(
    pipeline: transformers.Pipeline | diffusers.DiffusionPipeline,  # type: ignore
) -> dict:
    """
    Gets the appropriate Interface kwargs for a given Hugging Face transformers.Pipeline or diffusers.DiffusionPipeline.
    Parameters:
        pipeline (transformers.Pipeline | diffusers.DiffusionPipeline): the pipeline from which to create an interface
    Returns:
        (dict): a dictionary of kwargs that can be used to construct an Interface object. It contains the
            entries produced by the matching `handle_*_pipeline` helper (minus its "preprocess"/"postprocess"
            callables), plus "fn" (the prediction function) and "title".
    Raises:
        ValueError: if `pipeline` is neither a transformers nor a diffusers pipeline, or if no Interface
            configuration could be derived for it.
    """
    # Dispatch on the defining module of the pipeline's class, so neither library must be importable here.
    module_name = str(type(pipeline).__module__)
    is_transformers = module_name.startswith("transformers.pipelines.")
    if is_transformers:
        pipeline_info = handle_transformers_pipeline(pipeline)
    elif module_name.startswith("diffusers.pipelines."):
        pipeline_info = handle_diffusers_pipeline(pipeline)
    else:
        raise ValueError(
            "pipeline must be a transformers.pipeline or diffusers.pipeline"
        )

    # Validate eagerly: previously a falsy result fell back to `{}` and then crashed
    # with a KeyError on the `del` statements below instead of this intended error.
    if not pipeline_info:
        raise ValueError("pipeline_info can not be None.")

    def fn(*params):
        """Run preprocess -> pipeline -> postprocess for one set of Interface inputs."""
        data = pipeline_info["preprocess"](*params)
        if is_transformers:
            from transformers import pipelines

            # special cases: these pipelines are called with positional inputs
            # rather than keyword arguments
            if isinstance(
                pipeline,
                (
                    pipelines.text_classification.TextClassificationPipeline,
                    pipelines.text2text_generation.Text2TextGenerationPipeline,
                    pipelines.text2text_generation.TranslationPipeline,
                ),
            ):
                data = pipeline(*data)
            else:
                data = pipeline(**data)  # type: ignore
            # special case for object-detection:
            # the original input image is also sent to the postprocess function
            if isinstance(
                pipeline,
                pipelines.object_detection.ObjectDetectionPipeline,
            ):
                return pipeline_info["postprocess"](data, params[0])
            return pipeline_info["postprocess"](data)
        # Otherwise this is a diffusers pipeline (guaranteed by the dispatch above).
        data = pipeline(**data)  # type: ignore
        return pipeline_info["postprocess"](data)

    interface_info = pipeline_info.copy()
    interface_info["fn"] = fn
    # preprocess/postprocess are wrapped into `fn`; they are not Interface kwargs.
    del interface_info["preprocess"]
    del interface_info["postprocess"]

    # define the title/description of the Interface
    interface_info["title"] = (
        pipeline.model.config.name_or_path
        if is_transformers
        else pipeline.__class__.__name__
    )
    return interface_info


def load_from_js_pipeline(pipeline) -> dict:
    """
    Gets the Interface kwargs for a pipeline object from the `transformers_js_py` package.
    Parameters:
        pipeline: a transformers_js_py pipeline; it is awaited when called, and its `task` and
            `model.config._name_or_path` attributes are read to build the title.
    Returns:
        (dict): a dictionary with "fn" (an async prediction function), "inputs", "outputs" and "title"
    Raises:
        ValueError: if `pipeline`'s class is not defined in a `transformers_js_py` module.
    """
    if str(type(pipeline).__module__).startswith("transformers_js_py."):
        pipeline_info = handle_transformers_js_pipeline(pipeline)
    else:
        raise ValueError("pipeline must be a transformers_js_py's pipeline")

    async def fn(*params):
        """Run the (optional) preprocess, await the pipeline, then the (optional) postprocess."""
        preprocess = pipeline_info["preprocess"]
        postprocess = pipeline_info["postprocess"]
        # Some postprocess functions also need the raw Interface inputs.
        postprocess_takes_inputs = pipeline_info.get("postprocess_takes_inputs", False)

        preprocessed_params = preprocess(*params) if preprocess else params
        pipeline_output = await pipeline(*preprocessed_params)
        postprocessed_output = (
            postprocess(pipeline_output, *(params if postprocess_takes_inputs else ()))
            if postprocess
            else pipeline_output
        )
        return postprocessed_output

    interface_info = {
        "fn": fn,
        "inputs": pipeline_info["inputs"],
        "outputs": pipeline_info["outputs"],
        "title": f"{pipeline.task} ({pipeline.model.config._name_or_path})",
    }
    return interface_info