|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from argparse import ArgumentParser, Namespace |
|
from typing import Any, List, Optional |
|
|
|
from ..pipelines import Pipeline, get_supported_tasks, pipeline |
|
from ..utils import logging |
|
from . import BaseTransformersCLICommand |
|
|
|
|
|
# The serving dependencies (FastAPI, pydantic, starlette, uvicorn) are optional extras.
# Import them if available and record the outcome in ``_serve_dependencies_installed``.
try:
    from fastapi import Body, FastAPI, HTTPException
    from fastapi.routing import APIRoute
    from pydantic import BaseModel
    from starlette.responses import JSONResponse
    from uvicorn import run

    _serve_dependencies_installed = True
except (ImportError, AttributeError):
    # Minimal stand-ins so that the result models (which subclass ``BaseModel``) and the
    # endpoint signatures (which use ``Body(...)`` as default values) can still be defined.
    # ``FastAPI``, ``HTTPException``, ``APIRoute``, ``JSONResponse`` and ``run`` remain
    # undefined here; ``ServeCommand.__init__`` raises a RuntimeError on
    # ``_serve_dependencies_installed`` before any of them would be used.
    BaseModel = object

    def Body(*x, **y):
        # Placeholder for ``fastapi.Body``: accepts any arguments and returns None.
        pass

    _serve_dependencies_installed = False


# Module-level logger for the serve command.
logger = logging.get_logger("transformers-cli/serving")
|
|
|
|
|
def serve_command_factory(args: Namespace):
    """
    Build a :class:`ServeCommand` from the parsed ``serve`` command-line arguments.

    The pipeline is created immediately from ``args.task``, ``args.model``, ``args.config``,
    ``args.tokenizer`` and ``args.device``. A falsy ``--model`` value (e.g. an empty string)
    is passed on as ``None``.

    Returns: ServeCommand
    """
    serving_pipeline = pipeline(
        task=args.task,
        model=args.model or None,
        config=args.config,
        tokenizer=args.tokenizer,
        device=args.device,
    )
    return ServeCommand(serving_pipeline, args.host, args.port, args.workers)
|
|
|
|
|
class ServeModelInfoResult(BaseModel):
    """
    Expose model information.

    Response model of the ``GET /`` route registered by ``ServeCommand``.
    """

    # Attribute dictionary of the pipeline model's config (built with ``vars(...)`` in
    # ``ServeCommand.model_info``).
    infos: dict
|
|
|
|
|
class ServeTokenizeResult(BaseModel):
    """
    Tokenize result model.

    Response model of the ``POST /tokenize`` route registered by ``ServeCommand``.
    """

    # Tokens produced by the pipeline's tokenizer.
    tokens: List[str]
    # Integer ids for ``tokens``; left as None when the caller did not request ids.
    # The explicit ``= None`` matters: pydantic v2 treats ``Optional[...]`` without a default
    # as a *required* field, which would reject results built from ``tokens`` alone.
    # Under pydantic v1 this is identical to the implicit default.
    tokens_ids: Optional[List[int]] = None
|
|
|
|
|
class ServeDeTokenizeResult(BaseModel):
    """
    DeTokenize result model.

    Response model of the ``POST /detokenize`` route registered by ``ServeCommand``.
    """

    # Text decoded from the submitted token ids.
    text: str
|
|
|
|
|
class ServeForwardResult(BaseModel):
    """
    Forward result model.

    Response model of the ``POST /forward`` route registered by ``ServeCommand``.
    """

    # Whatever the pipeline returned for the request inputs (an empty list for empty inputs);
    # typed ``Any`` because the structure depends on the pipeline's task.
    output: Any
|
|
|
|
|
class ServeCommand(BaseTransformersCLICommand):
    """
    ``transformers-cli serve``: expose a pipeline through a small FastAPI REST application.

    Routes registered in ``__init__``:

    - ``GET /``: model configuration (:class:`ServeModelInfoResult`)
    - ``POST /tokenize``: text -> tokens and optionally token ids (:class:`ServeTokenizeResult`)
    - ``POST /detokenize``: token ids -> text (:class:`ServeDeTokenizeResult`)
    - ``POST /forward``: run the pipeline on the given inputs (:class:`ServeForwardResult`)
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformers-cli.

        Args:
            parser: Root parser to register command-specific arguments
        """
        serve_parser = parser.add_parser(
            "serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
        )
        serve_parser.add_argument(
            "--task",
            type=str,
            choices=get_supported_tasks(),
            help="The task to run the pipeline on",
        )
        serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
        serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
        serve_parser.add_argument("--workers", type=int, default=1, help="Number of http workers")
        serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
        serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
        serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
        serve_parser.add_argument(
            "--device",
            type=int,
            default=-1,
            help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
        )
        serve_parser.set_defaults(func=serve_command_factory)

    def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int):
        """
        Args:
            pipeline: Pipeline whose tokenizer and model back the endpoints.
            host: Interface the server binds to.
            port: Port the server listens on.
            workers: Worker count handed to ``uvicorn.run``.

        Raises:
            RuntimeError: If the optional serving dependencies could not be imported.
        """
        self._pipeline = pipeline

        self.host = host
        self.port = port
        # NOTE(review): uvicorn only supports workers > 1 when the app is given as an import
        # string; ``run`` passes an app instance, so values above 1 may be rejected at
        # startup — confirm against the installed uvicorn version.
        self.workers = workers

        if not _serve_dependencies_installed:
            raise RuntimeError(
                "Using serve command requires FastAPI and uvicorn. "
                'Please install transformers with [serving]: pip install "transformers[serving]". '
                "Or install FastAPI and uvicorn separately."
            )

        logger.info(f"Serving model over {host}:{port}")
        self._app = FastAPI(
            routes=[
                APIRoute(
                    "/",
                    self.model_info,
                    response_model=ServeModelInfoResult,
                    response_class=JSONResponse,
                    methods=["GET"],
                ),
                APIRoute(
                    "/tokenize",
                    self.tokenize,
                    response_model=ServeTokenizeResult,
                    response_class=JSONResponse,
                    methods=["POST"],
                ),
                APIRoute(
                    "/detokenize",
                    self.detokenize,
                    response_model=ServeDeTokenizeResult,
                    response_class=JSONResponse,
                    methods=["POST"],
                ),
                APIRoute(
                    "/forward",
                    self.forward,
                    response_model=ServeForwardResult,
                    response_class=JSONResponse,
                    methods=["POST"],
                ),
            ],
            # NOTE(review): FastAPI has no dedicated ``timeout`` option; this keyword appears to
            # land in its ``**extra`` catch-all and have no effect — confirm before relying on it.
            timeout=600,
        )

    def run(self):
        """Serve the FastAPI application with uvicorn on the configured host, port and workers."""
        # ``run`` below is the module-level ``uvicorn.run`` import, not this method.
        run(self._app, host=self.host, port=self.port, workers=self.workers)

    def model_info(self):
        """Return the attributes of the pipeline model's configuration."""
        return ServeModelInfoResult(infos=vars(self._pipeline.model.config))

    def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
        """
        Tokenize the provided input and optionally return the corresponding token ids.

        - **text_input**: String to tokenize
        - **return_ids**: Boolean flag indicating whether the tokens have to be converted to their integer mapping.
        """
        try:
            tokens_txt = self._pipeline.tokenizer.tokenize(text_input)
            # Always pass ``tokens_ids`` (None when not requested) so the result model validates
            # even under pydantic v2, where an Optional field without a default is required.
            tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt) if return_ids else None
            return ServeTokenizeResult(tokens=tokens_txt, tokens_ids=tokens_ids)
        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)}) from e

    def detokenize(
        self,
        tokens_ids: List[int] = Body(None, embed=True),
        skip_special_tokens: bool = Body(False, embed=True),
        cleanup_tokenization_spaces: bool = Body(True, embed=True),
    ):
        """
        Detokenize the provided token ids to readable text.

        - **tokens_ids**: List of token ids
        - **skip_special_tokens**: Flag indicating to not try to decode special tokens
        - **cleanup_tokenization_spaces**: Flag indicating to clean up tokenization spaces in the decoded text.
        """
        try:
            # Pass the flags by keyword so they cannot be bound to the wrong decode() parameter.
            decoded_str = self._pipeline.tokenizer.decode(
                tokens_ids,
                skip_special_tokens=skip_special_tokens,
                clean_up_tokenization_spaces=cleanup_tokenization_spaces,
            )
            return ServeDeTokenizeResult(text=decoded_str)
        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)}) from e

    async def forward(self, inputs=Body(None, embed=True)):
        """
        Run the pipeline on the provided inputs.

        - **inputs**: Input(s) passed as-is to the pipeline.
        """
        # ``inputs`` is None when the request omits it. ``len(None)`` used to raise an unhandled
        # TypeError here (outside the try block), so missing and empty inputs are treated alike.
        if not inputs:
            return ServeForwardResult(output=[])

        # NOTE(review): the pipeline call is synchronous inside a coroutine, so inference blocks
        # the event loop; a plain ``def`` endpoint would let FastAPI run it in its threadpool.
        try:
            output = self._pipeline(inputs)
            return ServeForwardResult(output=output)
        except Exception as e:
            raise HTTPException(500, {"error": str(e)}) from e
|
|