|
import base64 |
|
import contextlib |
|
import os |
|
import re |
|
import tempfile |
|
import warnings |
|
from collections.abc import AsyncIterator, Sequence |
|
from io import BytesIO |
|
from pathlib import Path |
|
from typing import TYPE_CHECKING, Any |
|
from urllib.parse import unquote |
|
|
|
import gradio_client.utils as client_utils |
|
import httpx |
|
from mcp import types |
|
from mcp.server import Server |
|
from mcp.server.sse import SseServerTransport |
|
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager |
|
from PIL import Image |
|
from starlette.applications import Starlette |
|
from starlette.requests import Request |
|
from starlette.responses import JSONResponse, Response |
|
from starlette.routing import Mount, Route |
|
from starlette.types import Receive, Scope, Send |
|
|
|
from gradio import processing_utils, route_utils, utils |
|
from gradio.blocks import BlockFunction |
|
from gradio.components import State |
|
from gradio.data_classes import FileData |
|
|
|
if TYPE_CHECKING: |
|
from gradio.blocks import BlockContext, Blocks |
|
from gradio.components import Component |
|
|
|
|
|
# Cache directory handed to the processing_utils cache helpers used below.
# A non-empty GRADIO_TEMP_DIR environment variable takes precedence; otherwise
# a "gradio" subdirectory of the system temporary directory is used.
DEFAULT_TEMP_DIR = (
    os.environ.get("GRADIO_TEMP_DIR")
    or str(Path(tempfile.gettempdir(), "gradio"))
)
|
|
|
|
|
class GradioMCPServer:
    """
    A class for creating an MCP server around a Gradio app.

    Every named API endpoint of the app that has ``show_api`` enabled and a
    Python function attached is exposed as an MCP tool. The tools can be served
    over SSE (see ``launch_mcp_on_sse``) or over the streamable HTTP transport
    (``handle_streamable_http``, whose session manager is run by ``lifespan``).

    Args:
        blocks: The Blocks app to create the MCP server for.
    """

    def __init__(self, blocks: "Blocks"):
        self.blocks = blocks
        self.api_info = self.blocks.get_api_info()
        self.mcp_server = self.create_mcp_server()
        # Updated by launch_mcp_on_sse once the app's root path is known.
        self.root_path = ""

        # When utils.get_space() returns a Space id, tool names are prefixed with
        # the part after the last "/", with non-alphanumerics replaced by "_".
        tool_prefix = utils.get_space()
        if tool_prefix:
            tool_prefix = tool_prefix.split("/")[-1] + "_"
            self.tool_prefix = re.sub(r"[^a-zA-Z0-9]", "_", tool_prefix)
        else:
            self.tool_prefix = ""
        self.tool_to_endpoint = self.get_tool_to_endpoint()
        self.warn_about_state_inputs()

        manager = StreamableHTTPSessionManager(
            app=self.mcp_server, json_response=False, stateless=True
        )

        async def handle_streamable_http(
            scope: Scope, receive: Receive, send: Send
        ) -> None:
            await manager.handle_request(scope, receive, send)

        @contextlib.asynccontextmanager
        async def lifespan(app: Starlette) -> AsyncIterator[None]:
            """Keep the streamable HTTP session manager running for the app's lifetime."""
            async with manager.run():
                yield

        self.lifespan = lifespan
        self.manager = manager
        self.handle_streamable_http = handle_streamable_http

    def get_route_path(self, request: Request) -> str:
        """
        Gets the route path of the MCP server based on the incoming request.
        Can be different depending on whether the request is coming from the MCP SSE transport or the HTTP transport.

        Parameters:
            request: The Starlette request being served.
        Returns:
            "/gradio_api/mcp/messages" for SSE message posts, otherwise "/gradio_api/mcp/http".
        """
        url = httpx.URL(str(request.url))
        url = url.copy_with(query=None)
        url = str(url).rstrip("/")
        if url.endswith("/gradio_api/mcp/messages"):
            return "/gradio_api/mcp/messages"
        return "/gradio_api/mcp/http"

    def get_tool_to_endpoint(self) -> dict[str, str]:
        """
        Gets all of the tools that are exposed by the Gradio app and also
        creates a mapping from the tool names to the endpoint names in the API docs.

        A tool is named after its Python function (falling back to the callable's
        class name, then to the endpoint name), prefixed with ``self.tool_prefix``.
        Name collisions are resolved by appending underscores.

        Returns:
            A dict mapping tool names to endpoint names (e.g. "/predict").
        """
        tool_to_endpoint: dict[str, str] = {}
        for endpoint_name, endpoint_info in self.api_info["named_endpoints"].items():
            if not endpoint_info["show_api"]:
                continue
            block_fn = self.get_block_fn_from_endpoint_name(endpoint_name)
            if block_fn is None or block_fn.fn is None:
                continue
            fn_name = (
                getattr(block_fn.fn, "__name__", None)
                or getattr(block_fn.fn.__class__, "__name__", None)
                or endpoint_name.lstrip("/")
            )
            tool_name = self.tool_prefix + fn_name
            while tool_name in tool_to_endpoint:
                tool_name += "_"
            tool_to_endpoint[tool_name] = endpoint_name
        return tool_to_endpoint

    def warn_about_state_inputs(self) -> None:
        """
        Warn about tools that have gr.State inputs, since MCP calls cannot carry
        session state between invocations.
        """
        for endpoint_name in self.tool_to_endpoint.values():
            block_fn = self.get_block_fn_from_endpoint_name(endpoint_name)
            if block_fn and any(isinstance(input, State) for input in block_fn.inputs):
                warnings.warn(
                    "This MCP server includes a tool that has a gr.State input, which will not be "
                    "updated between tool calls. The original, default value of the State will be "
                    "used each time."
                )

    def _describe_tool(
        self, tool_name: str, endpoint_name: str
    ) -> tuple[str, dict[str, Any]]:
        """
        Build the description and MCP input schema advertised for one tool.

        The description comes from the function's docstring (via
        utils.get_function_description); any documented return values are
        appended as " Returns: a, b".

        Parameters:
            tool_name: The MCP tool name (a key of self.tool_to_endpoint).
            endpoint_name: The endpoint the tool maps to, e.g. "/predict".
        Returns:
            A (description, input_schema) tuple.
        """
        block_fn = self.get_block_fn_from_endpoint_name(endpoint_name)
        assert block_fn is not None and block_fn.fn is not None
        description, parameters, returns = utils.get_function_description(block_fn.fn)
        if returns:
            description += (
                ("" if description.endswith(".") else ".")
                + " Returns: "
                + ", ".join(returns)
            )
        schema, _ = self.get_input_schema(tool_name, parameters)
        return description, schema

    def create_mcp_server(self) -> Server:
        """
        Create an MCP server for the given Gradio Blocks app.

        Registers a ``call_tool`` handler (runs the endpoint through
        ``Blocks.process_api``) and a ``list_tools`` handler.

        Returns:
            The MCP server.
        """
        server = Server(str(self.blocks.title or "Gradio App"))

        @server.call_tool()
        async def call_tool(
            name: str, arguments: dict[str, Any]
        ) -> list[types.TextContent | types.ImageContent]:
            """
            Call a tool on the Gradio app.

            Args:
                name: The name of the tool to call.
                arguments: The arguments to pass to the tool.
            Raises:
                ValueError: If there is no request in the MCP context, or the tool is unknown.
            """
            context_request = server.request_context.request
            if context_request is None:
                raise ValueError(
                    "Could not find the request object in the MCP server context. This is not expected to happen. Please raise an issue: https://github.com/gradio-app/gradio."
                )
            route_path = self.get_route_path(context_request)
            root_url = route_utils.get_root_url(
                request=context_request,
                route_path=route_path,
                root_path=self.root_path,
            )

            endpoint_name = self.tool_to_endpoint.get(name)
            if endpoint_name is None:
                raise ValueError(f"Unknown tool for this Gradio app: {name}")

            # File parameters arrive as URL strings; turn them back into FileData.
            _, filedata_positions = self.get_input_schema(name)
            processed_kwargs = self.convert_strings_to_filedata(
                arguments, filedata_positions
            )

            block_fn = self.get_block_fn_from_endpoint_name(endpoint_name)
            assert block_fn is not None

            if endpoint_name in self.api_info["named_endpoints"]:
                parameters_info = self.api_info["named_endpoints"][endpoint_name][
                    "parameters"
                ]
                processed_args = client_utils.construct_args(
                    parameters_info,
                    (),
                    processed_kwargs,
                )
            else:
                processed_args = []
            # State inputs are not part of the schema, so fill their slots with None.
            processed_args = self.insert_empty_state(block_fn.inputs, processed_args)
            output = await self.blocks.process_api(
                block_fn=block_fn,
                inputs=processed_args,
                request=context_request,
            )
            # State outputs are not part of the schema either; drop their values
            # so they are not reported back to the MCP client.
            output_data = self.pop_returned_state(block_fn.outputs, output["data"])
            return self.postprocess_output_data(output_data, root_url)

        @server.list_tools()
        async def list_tools() -> list[types.Tool]:
            """
            List all tools on the Gradio app.
            """
            tools = []
            for tool_name, endpoint_name in self.tool_to_endpoint.items():
                description, schema = self._describe_tool(tool_name, endpoint_name)
                tools.append(
                    types.Tool(
                        name=tool_name,
                        description=description,
                        inputSchema=schema,
                    )
                )
            return tools

        return server

    def launch_mcp_on_sse(self, app: Starlette, subpath: str, root_path: str) -> None:
        """
        Launch the MCP server on the SSE transport.

        Mounts a sub-application at ``subpath`` exposing ``/schema``, ``/sse``,
        ``/messages/`` and ``/http/``.

        Parameters:
            app: The Gradio app to mount the MCP server on.
            subpath: The subpath to mount the MCP server on. E.g. "/gradio_api/mcp"
            root_path: The root path of the Gradio Blocks app.
        """
        messages_path = "/messages/"
        sse = SseServerTransport(messages_path)
        self.root_path = root_path

        async def handle_sse(request):
            try:
                async with sse.connect_sse(
                    request.scope, request.receive, request._send
                ) as streams:
                    await self.mcp_server.run(
                        streams[0],
                        streams[1],
                        self.mcp_server.create_initialization_options(),
                    )
                return Response()
            except Exception as e:
                print(f"MCP SSE connection error: {str(e)}")
                raise

        app.mount(
            subpath,
            Starlette(
                routes=[
                    Route(
                        "/schema",
                        endpoint=self.get_complete_schema,
                    ),
                    Route("/sse", endpoint=handle_sse),
                    Mount("/messages/", app=sse.handle_post_message),
                    Mount("/http/", app=self.handle_streamable_http),
                ],
            ),
        )

    def get_block_fn_from_endpoint_name(
        self, endpoint_name: str
    ) -> "BlockFunction | None":
        """
        Get the BlockFunction for a given endpoint name (e.g. "/predict").

        Parameters:
            endpoint_name: The name of the endpoint to get the BlockFunction for.

        Returns:
            The BlockFunction for the given endpoint name, or None if it is not found.
        """
        api_name = endpoint_name.lstrip("/")
        return next(
            (fn for fn in self.blocks.fns.values() if fn.api_name == api_name),
            None,
        )

    @staticmethod
    def insert_empty_state(
        inputs: Sequence["Component | BlockContext"], data: list
    ) -> list:
        """
        Insert None placeholder values for any State input components, as State inputs
        are not included in the endpoint schema.

        Inserting in ascending index order is correct here: each insertion lands at
        the component's final position. ``data`` is modified in place and returned.
        """
        for i, input_component_type in enumerate(inputs):
            if isinstance(input_component_type, State):
                data.insert(i, None)
        return data

    @staticmethod
    def pop_returned_state(
        inputs: Sequence["Component | BlockContext"], data: list
    ) -> list:
        """
        Remove any values corresponding to State output components from the data
        as State outputs are not included in the endpoint schema.

        Parameters:
            inputs: Components aligned position-by-position with ``data``. To strip
                returned state, pass ``block_fn.outputs``. The parameter name is kept
                for backward compatibility.
            data: The values produced for those components.
        Returns:
            A new list with the values at State positions removed. Positions are
            computed up front, so several State components are handled correctly.
        """
        state_positions = {
            i for i, component in enumerate(inputs) if isinstance(component, State)
        }
        return [value for i, value in enumerate(data) if i not in state_positions]

    def get_input_schema(
        self,
        tool_name: str,
        parameters: dict[str, str] | None = None,
    ) -> tuple[dict[str, Any], list[list[str | int]]]:
        """
        Get the input schema of the Gradio app API, appropriately formatted for MCP.

        Parameters:
            tool_name: The name of the tool to get the schema for, e.g. "predict"
            parameters: Optional mapping of parameter name to a human-readable
                description, added as each property's "description".
        Returns:
            - The input schema of the Gradio app API.
            - A list of positions of FileData objects in the input schema.
        Raises:
            ValueError: If ``tool_name`` is not a known tool.
        """
        endpoint_name = self.tool_to_endpoint.get(tool_name)
        if endpoint_name is None:
            raise ValueError(f"Unknown tool for this Gradio app: {tool_name}")
        endpoint_info = self.api_info["named_endpoints"].get(endpoint_name)
        assert endpoint_info is not None

        properties: dict[str, Any] = {}
        for p in endpoint_info["parameters"]:
            param_name = p["parameter_name"]
            # Shallow copy so the top-level api_info entry is not modified.
            prop = {**p["type"]}
            if parameters and param_name in parameters:
                prop["description"] = parameters[param_name]
            # Compare with None so falsy defaults such as 0, False or "" are kept.
            if p.get("parameter_default") is not None:
                prop["default"] = p["parameter_default"]
            properties[param_name] = prop

        schema = {"type": "object", "properties": properties}
        return self.simplify_filedata_schema(schema)

    async def get_complete_schema(self, request) -> JSONResponse:
        """
        Get the complete schema of the Gradio app API. For debugging purposes, also used by
        the Hugging Face MCP server to get the schema for MCP Spaces without needing to
        establish an SSE connection.

        Parameters:
            request: The Starlette request object.

        Returns:
            A JSONResponse containing a list of {"name", "description", "inputSchema"}
            objects, one per tool (or an empty object when there is no API info).
        """
        if not self.api_info:
            return JSONResponse({})

        schemas = []
        for tool_name, endpoint_name in self.tool_to_endpoint.items():
            description, schema = self._describe_tool(tool_name, endpoint_name)
            schemas.append(
                {
                    "name": tool_name,
                    "description": description,
                    "inputSchema": schema,
                }
            )
        return JSONResponse(schemas)

    def simplify_filedata_schema(
        self, schema: dict[str, Any]
    ) -> tuple[dict[str, Any], list[list[str | int]]]:
        """
        Parses a schema of a Gradio app API to identify positions of FileData objects. Replaces them with
        URL strings while keeping track of their positions so that they can be converted back to FileData objects
        later.

        The input ``schema`` is not modified. Its nested dicts may be shared with
        ``self.api_info``, so mutating them would corrupt later schema generation.

        Parameters:
            schema: The original schema of the Gradio app API.

        Returns:
            A tuple containing the simplified schema and the positions of the FileData objects.
        """

        def resolve_ref(obj: Any, defs: dict[str, Any]) -> Any:
            """Follow a local "#/$defs/..." $ref; return None for any other $ref."""
            if isinstance(obj, dict) and "$ref" in obj:
                ref = obj["$ref"]
                if isinstance(ref, str) and ref.startswith("#/$defs/"):
                    return defs.get(ref.split("/")[-1], {})
                return None
            return obj

        def is_gradio_filedata(obj: Any, defs: dict[str, Any]) -> bool:
            """True if ``obj`` is a JSON schema for a gradio.FileData object."""
            obj = resolve_ref(obj, defs)
            if not isinstance(obj, dict):
                return False
            props = obj.get("properties", {})
            if not isinstance(props, dict):
                return False
            meta = resolve_ref(props.get("meta", {}), defs)
            if not isinstance(meta, dict):
                return False
            meta_props = meta.get("properties", {})
            type_field = (
                meta_props.get("_type", {}) if isinstance(meta_props, dict) else {}
            )
            default = meta.get("default", {})
            default_type = default.get("_type") if isinstance(default, dict) else None
            return (
                isinstance(type_field, dict)
                and type_field.get("const") == "gradio.FileData"
            ) or default_type == "gradio.FileData"

        filedata_positions: list[list[str | int]] = []

        def traverse(
            node: Any,
            path: list[str | int] | None = None,
            defs: dict[str, Any] | None = None,
        ) -> Any:
            if path is None:
                path = []
            if defs is None:
                defs = {}

            if isinstance(node, dict):
                if "$defs" in node:
                    defs.update(node["$defs"])

                if is_gradio_filedata(node, defs):
                    filedata_positions.append(path.copy())
                    # Build a copy instead of popping keys so the caller's schema
                    # (and anything sharing its nested dicts) stays intact.
                    node = {
                        key: value
                        for key, value in node.items()
                        if key not in ("properties", "additional_description", "$defs")
                    }
                    node["type"] = "string"
                    node["format"] = "a http or https url to a file"

                result = {}
                # Parameter names directly under the root "properties" are recorded
                # without a "properties" path segment.
                is_schema_root = "type" in node and "properties" in node
                for key, value in node.items():
                    if is_schema_root and key == "properties":
                        result[key] = traverse(value, path, defs)
                    else:
                        path.append(key)
                        result[key] = traverse(value, path, defs)
                        path.pop()
                return result

            if isinstance(node, list):
                result = []
                for i, item in enumerate(node):
                    path.append(i)
                    result.append(traverse(item, path, defs))
                    path.pop()
                return result

            return node

        simplified_schema = traverse(schema)
        return simplified_schema, filedata_positions

    def convert_strings_to_filedata(
        self, value: Any, filedata_positions: list[list[str | int]]
    ) -> Any:
        """
        Convert specific string values back to FileData objects based on their positions.
        This is used to convert string values (data URLs or http(s) URLs) to FileData
        objects so that they can be passed into .preprocess() logic of a Gradio app.

        Parameters:
            value: The input data to process, which can be an arbitrary nested data structure
                that may or may not contain strings that should be converted to FileData objects.
            filedata_positions: List of paths to positions in the input data that should be converted to FileData objects.

        Returns:
            The processed data with strings converted to FileData objects where appropriate. Base64
            encoded strings are first saved to a temporary file and then converted to a FileData object.

        Raises:
            ValueError: If a string at a FileData position is neither a data URL nor an http(s) URL.

        Example:
            >>> convert_strings_to_filedata(
                {"image": "data:image/jpeg;base64,..."},
                [["image"]]
            )
            >>> {'image': FileData(path='<temporary file path>')},
        """
        # Tuples are hashable, giving O(1) membership checks per visited string.
        targets = {tuple(position) for position in filedata_positions}

        def traverse(node: Any, path: list[str | int] | None = None) -> Any:
            if path is None:
                path = []

            if isinstance(node, dict):
                return {
                    key: traverse(value, path + [key]) for key, value in node.items()
                }
            if isinstance(node, list):
                return [traverse(item, path + [i]) for i, item in enumerate(node)]
            if isinstance(node, str) and tuple(path) in targets:
                if node.startswith("data:"):
                    return FileData(
                        path=processing_utils.save_base64_to_cache(
                            node, DEFAULT_TEMP_DIR
                        )
                    )
                if node.startswith(("http://", "https://")):
                    return FileData(path=node)
                raise ValueError(
                    f"Invalid file data format, provide a url ('http://...' or 'https://...'). Received: {node}"
                )
            return node

        return traverse(value)

    @staticmethod
    def get_image(file_path: str) -> Image.Image | None:
        """
        If a filepath is a valid image, returns a PIL Image object. Otherwise returns None.

        The returned image comes from ``Image.open``; callers should close it
        (e.g. with a ``with`` block) when done.
        """
        if not os.path.exists(file_path):
            return None
        ext = os.path.splitext(file_path.lower())[1]
        if ext not in Image.registered_extensions():
            return None
        try:
            return Image.open(file_path)
        except Exception:
            return None

    @staticmethod
    def get_svg(file_data: Any) -> bytes | None:
        """
        If a file_data is a valid FileDataDict with a url that is a data:image/svg+xml, returns bytes of the svg. Otherwise returns None.

        Both percent-encoded and ";base64" data URLs are supported. Malformed URLs
        (no "," separator, or invalid base64) yield None.
        """
        if not isinstance(file_data, dict):
            return None
        url = file_data.get("url")
        if not isinstance(url, str) or not url.startswith("data:image/svg"):
            return None
        header, separator, payload = url.partition(",")
        if not separator:
            return None
        if header.endswith(";base64"):
            try:
                return base64.b64decode(payload)
            except ValueError:  # binascii.Error is a ValueError subclass
                return None
        return unquote(payload).encode()

    @staticmethod
    def get_base64_data(image: Image.Image, format: str) -> str:
        """
        Returns a base64 encoded string of the image saved in the given format.
        """
        buffer = BytesIO()
        image.save(buffer, format=format)
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def postprocess_output_data(
        self, data: Any, root_url: str
    ) -> list[types.TextContent | types.ImageContent]:
        """
        Postprocess the output data from the Gradio app into MCP content items.

        - SVG data URLs become an ImageContent plus a TextContent with a cached file URL.
        - Image files become an ImageContent (base64) plus a TextContent with their URL.
        - Other files become a TextContent with their URL (or path).
        - Everything else becomes a TextContent with ``str(output)``.

        Parameters:
            data: The output data to postprocess.
            root_url: The app's root URL, used to build absolute file URLs.
        """
        return_values: list[types.TextContent | types.ImageContent] = []
        data = processing_utils.add_root_url(data, root_url, None)
        for output in data:
            if svg_bytes := self.get_svg(output):
                base64_data = base64.b64encode(svg_bytes).decode("utf-8")
                svg_name = output.get("orig_name") or "image.svg"
                svg_path = processing_utils.save_bytes_to_cache(
                    svg_bytes, f"{svg_name}", DEFAULT_TEMP_DIR
                )
                svg_url = f"{root_url}/gradio_api/file={svg_path}"
                return_values.extend(
                    [
                        types.ImageContent(
                            type="image", data=base64_data, mimeType="image/svg+xml"
                        ),
                        types.TextContent(
                            type="text",
                            text=f"SVG Image URL: {svg_url}",
                        ),
                    ]
                )
            elif client_utils.is_file_obj_with_meta(output):
                file_url = output.get("url") or output["path"]
                image = self.get_image(output["path"])
                if image is not None:
                    # Close the file handle opened by Image.open once encoded.
                    with image:
                        image_format = image.format or "png"
                        base64_data = self.get_base64_data(image, image_format)
                    mimetype = f"image/{image_format.lower()}"
                    return_values.extend(
                        [
                            types.ImageContent(
                                type="image", data=base64_data, mimeType=mimetype
                            ),
                            types.TextContent(
                                type="text",
                                text=f"Image URL: {file_url}",
                            ),
                        ]
                    )
                else:
                    return_values.append(
                        types.TextContent(type="text", text=str(file_url))
                    )
            else:
                return_values.append(types.TextContent(type="text", text=str(output)))
        return return_values
|
|