# Copyright 2020 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from argparse import ArgumentParser from ..pipelines import Pipeline, PipelineDataFormat, get_supported_tasks, pipeline from ..utils import logging from . import BaseTransformersCLICommand logger = logging.get_logger(__name__) # pylint: disable=invalid-name def try_infer_format_from_ext(path: str): if not path: return "pipe" for ext in PipelineDataFormat.SUPPORTED_FORMATS: if path.endswith(ext): return ext raise Exception( f"Unable to determine file format from file extension {path}. " f"Please provide the format through --format {PipelineDataFormat.SUPPORTED_FORMATS}" ) def run_command_factory(args): nlp = pipeline( task=args.task, model=args.model if args.model else None, config=args.config, tokenizer=args.tokenizer, device=args.device, ) format = try_infer_format_from_ext(args.input) if args.format == "infer" else args.format reader = PipelineDataFormat.from_str( format=format, output_path=args.output, input_path=args.input, column=args.column if args.column else nlp.default_input_names, overwrite=args.overwrite, ) return RunCommand(nlp, reader) class RunCommand(BaseTransformersCLICommand): def __init__(self, nlp: Pipeline, reader: PipelineDataFormat): self._nlp = nlp self._reader = reader @staticmethod def register_subcommand(parser: ArgumentParser): run_parser = parser.add_parser("run", help="Run a pipeline through the CLI") run_parser.add_argument("--task", choices=get_supported_tasks(), help="Task to run") run_parser.add_argument("--input", type=str, help="Path to the file to use for inference") run_parser.add_argument("--output", type=str, help="Path to the file that will be used post to write results.") run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.") run_parser.add_argument("--config", type=str, help="Name or path to the model's config to instantiate.") run_parser.add_argument( "--tokenizer", type=str, help="Name of the tokenizer to use. (default: same as the model name)" ) run_parser.add_argument( "--column", type=str, help="Name of the column to use as input. (For multi columns input as QA use column1,columns2)", ) run_parser.add_argument( "--format", type=str, default="infer", choices=PipelineDataFormat.SUPPORTED_FORMATS, help="Input format to read from", ) run_parser.add_argument( "--device", type=int, default=-1, help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)", ) run_parser.add_argument("--overwrite", action="store_true", help="Allow overwriting the output file.") run_parser.set_defaults(func=run_command_factory) def run(self): nlp, outputs = self._nlp, [] for entry in self._reader: output = nlp(**entry) if self._reader.is_multi_columns else nlp(entry) if isinstance(output, dict): outputs.append(output) else: outputs += output # Saving data if self._nlp.binary_output: binary_path = self._reader.save_binary(outputs) logger.warning(f"Current pipeline requires output to be in binary format, saving at {binary_path}") else: self._reader.save(outputs)