# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional

from ..pipelines import Pipeline, get_supported_tasks, pipeline
from ..utils import logging
from . import BaseTransformersCLICommand


try:
    from fastapi import Body, FastAPI, HTTPException
    from fastapi.routing import APIRoute
    from pydantic import BaseModel
    from starlette.responses import JSONResponse
    from uvicorn import run

    _serve_dependencies_installed = True
except (ImportError, AttributeError):
    # Placeholders so the result models and the ``Body(...)`` default arguments below can
    # still be defined when the optional serving dependencies are missing.
    # ServeCommand.__init__ refuses to start in that case.
    BaseModel = object

    def Body(*x, **y):
        pass

    _serve_dependencies_installed = False


logger = logging.get_logger("transformers-cli/serving")


def serve_command_factory(args: Namespace):
    """
    Factory function used to instantiate serving server from provided command line arguments.

    Args:
        args: Parsed command-line namespace. Reads ``task``, ``model``, ``config``, ``tokenizer``,
            ``device``, ``host``, ``port`` and ``workers``.

    Returns: ServeCommand
    """
    nlp = pipeline(
        task=args.task,
        # An empty model string is normalized to None.
        model=args.model if args.model else None,
        config=args.config,
        tokenizer=args.tokenizer,
        device=args.device,
    )
    return ServeCommand(nlp, args.host, args.port, args.workers)


class ServeModelInfoResult(BaseModel):
    """
    Expose model information
    """

    infos: dict


class ServeTokenizeResult(BaseModel):
    """
    Tokenize result model
    """

    tokens: List[str]
    # Explicit default: pydantic v2 treats an Optional field without a default as *required*,
    # which would make ``ServeTokenizeResult(tokens=...)`` fail validation.
    tokens_ids: Optional[List[int]] = None


class ServeDeTokenizeResult(BaseModel):
    """
    DeTokenize result model
    """

    text: str


class ServeForwardResult(BaseModel):
    """
    Forward result model
    """

    output: Any


class ServeCommand(BaseTransformersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformer-cli

        Args:
            parser: Root parser to register command-specific arguments
        """
        serve_parser = parser.add_parser(
            "serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
        )
        serve_parser.add_argument(
            "--task",
            type=str,
            choices=get_supported_tasks(),
            help="The task to run the pipeline on",
        )
        serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
        serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
        serve_parser.add_argument("--workers", type=int, default=1, help="Number of http workers")
        serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
        serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
        serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
        serve_parser.add_argument(
            "--device",
            type=int,
            default=-1,
            help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
        )
        serve_parser.set_defaults(func=serve_command_factory)

    def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int):
        """
        Build the FastAPI application exposing the given pipeline.

        Args:
            pipeline: The pipeline whose model/tokenizer back every endpoint.
            host: Interface passed to uvicorn.
            port: Port passed to uvicorn.
            workers: Number of uvicorn workers.

        Raises:
            RuntimeError: If FastAPI / uvicorn (the serving extras) could not be imported.
        """
        self._pipeline = pipeline

        self.host = host
        self.port = port
        self.workers = workers

        if not _serve_dependencies_installed:
            raise RuntimeError(
                "Using serve command requires FastAPI and uvicorn. "
                'Please install transformers with [serving]: pip install "transformers[serving]". '
                "Or install FastAPI and uvicorn separately."
            )

        logger.info(f"Serving model over {host}:{port}")
        self._app = FastAPI(
            routes=[
                APIRoute(
                    "/",
                    self.model_info,
                    response_model=ServeModelInfoResult,
                    response_class=JSONResponse,
                    methods=["GET"],
                ),
                APIRoute(
                    "/tokenize",
                    self.tokenize,
                    response_model=ServeTokenizeResult,
                    response_class=JSONResponse,
                    methods=["POST"],
                ),
                APIRoute(
                    "/detokenize",
                    self.detokenize,
                    response_model=ServeDeTokenizeResult,
                    response_class=JSONResponse,
                    methods=["POST"],
                ),
                APIRoute(
                    "/forward",
                    self.forward,
                    response_model=ServeForwardResult,
                    response_class=JSONResponse,
                    methods=["POST"],
                ),
            ],
            timeout=600,
        )

    def run(self):
        """Start the uvicorn server for the FastAPI app built in ``__init__``."""
        # ``run`` here is the module-level uvicorn.run, not this method.
        # NOTE(review): uvicorn presumably needs the app as an import string (not an app object)
        # for workers > 1 — confirm before relying on --workers.
        run(self._app, host=self.host, port=self.port, workers=self.workers)

    def model_info(self):
        """Return the served model's configuration attributes (``vars(model.config)``)."""
        return ServeModelInfoResult(infos=vars(self._pipeline.model.config))

    def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
        """
        Tokenize the provided input and optionally return the corresponding token ids:
        - **text_input**: String to tokenize
        - **return_ids**: Boolean flag indicating whether the tokens should also be converted to their integer ids.
        """
        try:
            tokens_txt = self._pipeline.tokenizer.tokenize(text_input)

            if return_ids:
                tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt)
                return ServeTokenizeResult(tokens=tokens_txt, tokens_ids=tokens_ids)
            return ServeTokenizeResult(tokens=tokens_txt)
        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)}) from e

    def detokenize(
        self,
        tokens_ids: List[int] = Body(None, embed=True),
        skip_special_tokens: bool = Body(False, embed=True),
        cleanup_tokenization_spaces: bool = Body(True, embed=True),
    ):
        """
        Detokenize the provided tokens ids to readable text:
        - **tokens_ids**: List of tokens ids
        - **skip_special_tokens**: Flag indicating whether special tokens should be dropped from the decoded text
        - **cleanup_tokenization_spaces**: Flag forwarded to the tokenizer's `clean_up_tokenization_spaces` option.
        """
        try:
            # Keyword arguments make the mapping onto the tokenizer's decode() parameters explicit.
            decoded_str = self._pipeline.tokenizer.decode(
                tokens_ids,
                skip_special_tokens=skip_special_tokens,
                clean_up_tokenization_spaces=cleanup_tokenization_spaces,
            )
            return ServeDeTokenizeResult(text=decoded_str)
        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)}) from e

    async def forward(self, inputs=Body(None, embed=True)):
        """
        Run the provided inputs through the served pipeline:
        - **inputs**: Input(s) passed as-is to the pipeline. A missing or empty payload yields an empty output.
        """
        # Missing (None) and empty payloads both short-circuit; ``len(None)`` would otherwise
        # raise a TypeError outside the try block below.
        if not inputs:
            return ServeForwardResult(output=[])

        try:
            # NOTE(review): this synchronous pipeline call runs inside an ``async def`` endpoint,
            # so it presumably blocks the event loop while the model runs — confirm if concurrency matters.
            output = self._pipeline(inputs)
            return ServeForwardResult(output=output)
        except Exception as e:
            raise HTTPException(500, {"error": str(e)}) from e