# MLflow_mcp / app.py
# tuntun's picture
# fix bug
# bab875d
import json
import gradio as gr
from mcp_mlflow_tools import (
set_tracking_uri,
get_system_info,
list_experiments,
create_experiment,
register_model,
search_runs,
list_registered_models,
get_model_info
)
def create_interface():
    """Build the Gradio Blocks UI that fronts the MLflow helper functions.

    The layout has four tabs, each wiring buttons straight to the helpers
    imported from ``mcp_mlflow_tools``:

    * "Tracking & System Info" -- ``set_tracking_uri``, ``get_system_info``
    * "Experiment Management"  -- ``list_experiments``, ``create_experiment``
      (through the local ``create_exp_with_tags`` wrapper, which parses the
      tags text box)
    * "Model Registry"         -- ``register_model``,
      ``list_registered_models``, ``get_model_info``
    * "Run Search"             -- ``search_runs``

    Every handler renders its return value in a ``gr.JSON`` component.

    Returns:
        The assembled ``gr.Blocks`` application. It is not launched here;
        the caller decides how to serve it.
    """
    with gr.Blocks(title="MLflow MCP Service") as app:
        gr.Markdown("# MLflow MCP Service")
        gr.Markdown("A service that exposes MLflow functionality through a web interface and API endpoints.")

        with gr.Tab("Tracking & System Info"):
            with gr.Group():
                gr.Markdown("## Set Tracking URI")
                uri_input = gr.Textbox(label="MLflow Tracking URI")
                uri_output = gr.JSON(label="Result")
                uri_button = gr.Button("Set URI")
                uri_button.click(
                    fn=set_tracking_uri,
                    inputs=uri_input,
                    outputs=uri_output,
                )

            with gr.Group():
                gr.Markdown("## Get System Info")
                sys_info_output = gr.JSON(label="System Information")
                sys_info_button = gr.Button("Get Info")
                sys_info_button.click(
                    fn=get_system_info,
                    inputs=[],
                    outputs=sys_info_output,
                )

        with gr.Tab("Experiment Management"):
            with gr.Group():
                gr.Markdown("## List Experiments")
                exp_list_output = gr.JSON(label="Experiments")
                exp_list_button = gr.Button("List Experiments")
                exp_list_button.click(
                    fn=list_experiments,
                    inputs=[],
                    outputs=exp_list_output,
                )

            with gr.Group():
                gr.Markdown("## Create Experiment")
                exp_name_input = gr.Textbox(label="Experiment Name")
                exp_tags_input = gr.Textbox(label="Tags (JSON format)", placeholder='{"key": "value"}')
                exp_create_output = gr.JSON(label="Result")
                exp_create_button = gr.Button("Create Experiment")

                # NOTE: the handler's name, parameter names and docstring are
                # what Gradio publishes for this endpoint (and, when launched
                # with mcp_server=True, for the corresponding MCP tool), so
                # keep the signature stable.
                def create_exp_with_tags(name, tags_str):
                    """Create a new MLflow experiment with optional tags.

                    Args:
                        name: Name of the experiment to create.
                        tags_str: Optional tags as a JSON object string, for example {"key": "value"}. Leave empty for no tags.

                    Returns:
                        The result of create_experiment, or a dictionary with "error" and "message" keys when tags_str cannot be used.
                    """
                    # Blank (or whitespace-only) input means "no tags" rather
                    # than a JSON syntax error.
                    raw_tags = (tags_str or "").strip()
                    if not raw_tags:
                        return create_experiment(name, None)

                    try:
                        tags = json.loads(raw_tags)
                    except json.JSONDecodeError:
                        return {"error": True, "message": "Invalid JSON format for tags"}

                    # Valid JSON is not necessarily an object: "[1, 2]" or
                    # "42" parse fine but are not key/value tags. JSON null is
                    # still accepted and forwarded as None, as before.
                    if tags is not None and not isinstance(tags, dict):
                        return {
                            "error": True,
                            "message": 'Tags must be a JSON object, e.g. {"key": "value"}',
                        }

                    return create_experiment(name, tags)

                exp_create_button.click(
                    fn=create_exp_with_tags,
                    inputs=[exp_name_input, exp_tags_input],
                    outputs=exp_create_output,
                )

        with gr.Tab("Model Registry"):
            with gr.Group():
                gr.Markdown("## Register Model")
                reg_run_id = gr.Textbox(label="Run ID")
                reg_artifact_path = gr.Textbox(label="Artifact Path")
                reg_model_name = gr.Textbox(label="Model Name")
                reg_output = gr.JSON(label="Result")
                reg_button = gr.Button("Register Model")
                reg_button.click(
                    fn=register_model,
                    inputs=[reg_run_id, reg_artifact_path, reg_model_name],
                    outputs=reg_output,
                )

            with gr.Group():
                gr.Markdown("## List Registered Models")
                list_models_output = gr.JSON(label="Models")
                list_models_button = gr.Button("List Models")
                list_models_button.click(
                    fn=list_registered_models,
                    inputs=[],
                    outputs=list_models_output,
                )

            with gr.Group():
                gr.Markdown("## Get Model Info")
                model_info_name = gr.Textbox(label="Model Name")
                model_info_output = gr.JSON(label="Model Information")
                model_info_button = gr.Button("Get Info")
                model_info_button.click(
                    fn=get_model_info,
                    inputs=model_info_name,
                    outputs=model_info_output,
                )

        with gr.Tab("Run Search"):
            with gr.Group():
                gr.Markdown("## Search Runs")
                search_exp_id = gr.Textbox(label="Experiment ID")
                search_filter = gr.Textbox(label="Filter String")
                search_order_by = gr.Textbox(label="Order By")
                # precision=0 makes the Number component yield an integer.
                search_max_results = gr.Number(label="Max Results", value=100, precision=0)
                search_output = gr.JSON(label="Search Results")
                search_button = gr.Button("Search")
                search_button.click(
                    fn=search_runs,
                    inputs=[search_exp_id, search_filter, search_order_by, search_max_results],
                    outputs=search_output,
                )

    return app
if __name__ == "__main__":
    # Script entry point: assemble the Blocks UI, then serve it.
    # mcp_server=True asks Gradio to start its MCP server alongside the
    # regular web interface.
    app = create_interface()
    app.launch(mcp_server=True)