import asyncio
import logging
from concurrent.futures import Executor, ProcessPoolExecutor
from datetime import datetime, timezone
from functools import partial
from multiprocessing import freeze_support
from typing import Set, Tuple
|
|
| try: |
| from aiohttp import web |
|
|
| from .middlewares import cors |
| except ImportError as ie: |
| raise ImportError( |
| f"aiohttp dependency is not installed: {ie}. " |
| + "Please re-install black with the '[d]' extra install " |
| + "to obtain aiohttp_cors: `pip install black[d]`" |
| ) from None |
|
|
| import click |
|
|
| import black |
| from _black_version import version as __version__ |
| from black.concurrency import maybe_install_uvloop |
|
|
| |
# Module-level event. It is not referenced in this part of the file;
# presumably other code (e.g. tests) sets it to stop the server early --
# TODO confirm.
_stop_signal = asyncio.Event()

# Request headers through which clients pass formatting options to blackd.
PROTOCOL_VERSION_HEADER = "X-Protocol-Version"
LINE_LENGTH_HEADER = "X-Line-Length"
PYTHON_VARIANT_HEADER = "X-Python-Variant"
SKIP_SOURCE_FIRST_LINE = "X-Skip-Source-First-Line"
SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization"
SKIP_MAGIC_TRAILING_COMMA = "X-Skip-Magic-Trailing-Comma"
PREVIEW = "X-Preview"
FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe"
DIFF_HEADER = "X-Diff"

# All option headers above, in declaration order. make_app() lets these
# (plus Content-Type) through its CORS middleware.
BLACK_HEADERS = [
    PROTOCOL_VERSION_HEADER, LINE_LENGTH_HEADER, PYTHON_VARIANT_HEADER,
    SKIP_SOURCE_FIRST_LINE, SKIP_STRING_NORMALIZATION_HEADER,
    SKIP_MAGIC_TRAILING_COMMA, PREVIEW, FAST_OR_SAFE_HEADER, DIFF_HEADER,
]

# Response header that reports the Black version which produced the output.
BLACK_VERSION_HEADER = "X-Black-Version"
|
|
|
|
class InvalidVariantHeader(Exception):
    """Signal that an X-Python-Variant header value could not be parsed.

    The first positional argument is a short human-readable reason, which
    the request handler echoes back in its 400 response.
    """
|
|
|
|
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option(
    "--bind-host", type=str, help="Address to bind the server to.", default="localhost"
)
@click.option("--bind-port", type=int, help="Port to listen on", default=45484)
@click.version_option(version=black.__version__)
def main(bind_host: str, bind_port: int) -> None:
    # This is documented with comments rather than a docstring, because
    # click would show a docstring as the command's --help text.
    # Builds the blackd application, announces the address, then hands
    # control to aiohttp's run_app on bind_host:bind_port.
    logging.basicConfig(level=logging.INFO)
    application = make_app()
    version = black.__version__
    black.out(
        f"blackd version {version} listening on {bind_host} port {bind_port}"
    )
    # print=None keeps aiohttp from printing its own startup banner; the
    # black.out() line above replaces it.
    web.run_app(
        application,
        host=bind_host,
        port=bind_port,
        handle_signals=True,
        print=None,
    )
|
|
|
|
def make_app() -> web.Application:
    """Build the blackd aiohttp application.

    The app routes ``POST /`` to :func:`handle`. It is wrapped in the CORS
    middleware so that browser clients may send the blackd option headers
    (``BLACK_HEADERS``) plus ``Content-Type``. Formatting work is offloaded to
    a ``ProcessPoolExecutor`` owned by this app. The pool is shut down when
    the app is cleaned up.
    """
    app = web.Application(
        middlewares=[cors(allow_headers=(*BLACK_HEADERS, "Content-Type"))]
    )
    executor = ProcessPoolExecutor()

    async def _shutdown_executor(_app: web.Application) -> None:
        # Release the pool's worker processes together with the app.
        # Otherwise every make_app() call leaves a pool alive after its app is
        # gone. wait=False avoids blocking the event loop on worker exit.
        executor.shutdown(wait=False)

    app.on_cleanup.append(_shutdown_executor)
    app.add_routes([web.post("/", partial(handle, executor=executor))])
    return app
|
|
|
|
async def handle(request: web.Request, executor: Executor) -> web.Response:
    """Format the Python source in the body of *request* with Black.

    Formatting options are read from the ``X-*`` request headers listed in
    ``BLACK_HEADERS``. The formatting (and the optional diff) runs in
    *executor*, so the event loop stays free to serve other requests.

    Status codes:
        200 -- body is the reformatted source, or a diff of it when the
               X-Diff header is non-empty;
        204 -- the source needed no changes;
        400 -- a header value was invalid, or Black raised InvalidInput;
        500 -- any other exception (it is logged);
        501 -- an X-Protocol-Version other than "1" was requested.

    Every response carries the X-Black-Version header.
    """
    headers = {BLACK_VERSION_HEADER: __version__}
    try:
        if request.headers.get(PROTOCOL_VERSION_HEADER, "1") != "1":
            return web.Response(
                status=501,
                headers=headers,
                text="This server only supports protocol version 1",
            )
        try:
            line_length = int(
                request.headers.get(LINE_LENGTH_HEADER, black.DEFAULT_LINE_LENGTH)
            )
        except ValueError:
            return web.Response(
                status=400, headers=headers, text="Invalid line length header value"
            )

        if PYTHON_VARIANT_HEADER in request.headers:
            value = request.headers[PYTHON_VARIANT_HEADER]
            try:
                pyi, versions = parse_python_variant_header(value)
            except InvalidVariantHeader as e:
                return web.Response(
                    status=400,
                    headers=headers,
                    text=f"Invalid value for {PYTHON_VARIANT_HEADER}: {e.args[0]}",
                )
        else:
            pyi = False
            versions = set()

        # Boolean switches: any non-empty header value turns the option on.
        skip_string_normalization = bool(
            request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False)
        )
        skip_magic_trailing_comma = bool(
            request.headers.get(SKIP_MAGIC_TRAILING_COMMA, False)
        )
        skip_source_first_line = bool(
            request.headers.get(SKIP_SOURCE_FIRST_LINE, False)
        )
        preview = bool(request.headers.get(PREVIEW, False))
        fast = request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast"
        mode = black.FileMode(
            target_versions=versions,
            is_pyi=pyi,
            line_length=line_length,
            skip_source_first_line=skip_source_first_line,
            string_normalization=not skip_string_normalization,
            magic_trailing_comma=not skip_magic_trailing_comma,
            preview=preview,
        )
        req_bytes = await request.content.read()
        charset = request.charset if request.charset is not None else "utf8"
        req_str = req_bytes.decode(charset)
        # Naive UTC timestamp: the same value datetime.utcnow() returned, but
        # without its Python 3.12 deprecation. Keeping it naive keeps the
        # "... +0000" diff labels below unchanged.
        then = datetime.now(timezone.utc).replace(tzinfo=None)

        # Keep the first line out of formatting and re-attach it afterwards.
        # As with reading one line from a file, a body without any "\n" is
        # entirely its first line. The old find()+1 split skipped nothing
        # in that case.
        header = ""
        if skip_source_first_line:
            first_line, newline, rest = req_str.partition("\n")
            header = first_line + newline
            req_str = rest

        loop = asyncio.get_running_loop()
        formatted_str = await loop.run_in_executor(
            executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
        )

        # Preserve CRLF line endings, judged by the first line ending.
        # Requiring first_newline > 0 avoids inspecting an unrelated character
        # (index -2 when there is no "\n", -1 when the body starts with "\n").
        # It also avoids an IndexError on very short bodies.
        first_newline = req_str.find("\n")
        if first_newline > 0 and req_str[first_newline - 1] == "\r":
            formatted_str = formatted_str.replace("\n", "\r\n")
        # If, after restoring line endings, nothing changed, then say so.
        if formatted_str == req_str:
            raise black.NothingChanged

        # Put back the first line that was kept out of formatting.
        req_str = header + req_str
        formatted_str = header + formatted_str

        # Optionally answer with a diff instead of the full formatted text.
        only_diff = bool(request.headers.get(DIFF_HEADER, False))
        if only_diff:
            now = datetime.now(timezone.utc).replace(tzinfo=None)
            src_name = f"In\t{then} +0000"
            dst_name = f"Out\t{now} +0000"
            formatted_str = await loop.run_in_executor(
                executor,
                partial(black.diff, req_str, formatted_str, src_name, dst_name),
            )

        return web.Response(
            content_type=request.content_type,
            charset=charset,
            headers=headers,
            text=formatted_str,
        )
    except black.NothingChanged:
        return web.Response(status=204, headers=headers)
    except black.InvalidInput as e:
        return web.Response(status=400, headers=headers, text=str(e))
    except Exception as e:
        logging.exception("Exception during handling a request")
        return web.Response(status=500, headers=headers, text=str(e))
|
|
|
|
def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersion]]:
    """Parse the value of the X-Python-Variant request header.

    *value* is either ``"pyi"`` (format as a stub file) or a comma-separated
    list of Python 3 versions. Items such as ``"3.7"``, ``"py3.8"``, ``"38"``,
    ``"py310"`` or ``"3"`` are accepted. Whitespace around an item is
    ignored. A bare ``"3"`` means 3.3.

    Returns:
        ``(is_pyi, target_versions)``.

    Raises:
        InvalidVariantHeader: if an item is empty or malformed, or names a
            version that ``black.TargetVersion`` does not define (including
            any Python 2 version).
    """
    if value == "pyi":
        return True, set()

    versions: Set[black.TargetVersion] = set()
    for version in value.split(","):
        # Tolerate "3.7, 3.8"-style lists.
        version = version.strip()
        if version.startswith("py"):
            version = version[len("py") :]
        if not version:
            # Without this check, version[0] below raised IndexError outside
            # the try. That surfaced as a 500 for e.g. "", "py" or "3.7,".
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'")
        if "." in version:
            major_str, *rest = version.split(".")
        else:
            # Dot-less form: first digit is the major, the remainder the minor.
            major_str = version[0]
            rest = [version[1:]] if len(version) > 1 else []
        try:
            major = int(major_str)
            if major not in (2, 3):
                raise InvalidVariantHeader("major version must be 2 or 3")
            if major == 2:
                # Reject every 2.x spelling uniformly. A bare "2" previously
                # fell through to an obscure lookup failure.
                raise InvalidVariantHeader("Python 2 is not supported")
            # Without an explicit minor, default to the lowest one, 3.3.
            minor = int(rest[0]) if rest else 3
            version_str = f"PY{major}{minor}"
            if not hasattr(black.TargetVersion, version_str):
                raise InvalidVariantHeader(f"3.{minor} is not supported")
            versions.add(black.TargetVersion[version_str])
        except (KeyError, ValueError):
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") from None
    return False, versions
|
|
|
|
def patched_main() -> None:
    """Console entry point: prepare the runtime, then run the click ``main``."""
    # Setup runs strictly in this order, before the command starts:
    #  - maybe_install_uvloop (from black.concurrency, not shown here);
    #    presumably installs uvloop when it is available -- TODO confirm.
    #  - freeze_support lets multiprocessing (used by the ProcessPoolExecutor
    #    in make_app) work in frozen Windows executables.
    #  - black.patch_click (defined in black, not shown); presumably adjusts
    #    click behaviour before the command runs -- TODO confirm.
    setup_steps = (maybe_install_uvloop, freeze_support, black.patch_click)
    for step in setup_steps:
        step()
    main()
|
|
|
|
# Allow starting the server by running this module directly as a script.
if __name__ == "__main__":
    patched_main()
|
|