"""Gradio dashboard that plots OCRBench scores against model size.

Loads the OCRBench results CSV, tags each model with a display category,
and serves a scatter plot (score vs. parameter count) whose metric is chosen
from a dropdown.
"""

import pandas as pd
import plotly.express as px
import gradio as gr

data_path = '0926-OCRBench-opensource.csv'
# Missing scores become 0; plot_metric later treats 0 as "not reported".
data = pd.read_csv(data_path).fillna(0)

# set the data types for the columns
dtype_dict = {
    "Model": str,
    "Param (B)": float,
    "OCRBench": int,
    "Text Recognition": int,
    "Scene Text-centric VQA": int,
    "Document Oriented VQA": int,
    "KIE": int,
    "Handwritten Math Expression Recognition": int,
}

# Only the first 25 rows are used. NOTE(review): presumably any rows below
# are not model results (notes/footers) — confirm when the CSV is updated.
NUM_MODEL_ROWS = 25
# Metric shown when the app first loads (also the dropdown's initial value).
DEFAULT_METRIC = "OCRBench"
# Models highlighted as their own category in the plot.
H2OVL_MODELS = frozenset({"H2OVL-Mississippi-2B", "H2OVL-Mississippi-1B"})

# preprocess the dataframe
data_valid = data.iloc[:NUM_MODEL_ROWS].copy()
data_valid = data_valid.astype(dtype_dict)
# Drop the empty trailing column pandas auto-names 'Unnamed: 11'.
# errors="ignore" keeps the app loading if a re-exported CSV no longer has it.
data_valid = data_valid.drop(columns=['Unnamed: 11'], errors="ignore")


def categorize_model(model):
    """Return the display category for a model name.

    Args:
        model: Model name as it appears in the "Model" column.

    Returns:
        "H2OVLs" for the H2OVL-Mississippi models, "traditional ocr models"
        for names starting with "doctr", and "Other" for everything else.
    """
    if model in H2OVL_MODELS:
        return "H2OVLs"
    if model.startswith("doctr"):
        # Third group: classical OCR pipelines rather than VLMs.
        return "traditional ocr models"
    return "Other"


# Apply the categorization to create a new column
data_valid["Category"] = data_valid["Model"].apply(categorize_model)


# plotting
def plot_metric(selected_metric):
    """Build a scatter plot of ``selected_metric`` against model size.

    Rows whose score for the metric is 0 are left out. 0 is the placeholder
    written by ``fillna(0)`` at load time, so these are treated as missing.
    Points are coloured by ``Category`` and labelled with the model name.

    Args:
        selected_metric: Name of a score column in ``data_valid``.

    Returns:
        The plotly figure produced by ``plotly.express.scatter``.
    """
    filtered_data = data_valid[data_valid[selected_metric] != 0]

    fig = px.scatter(
        filtered_data,
        x="Param (B)",
        y=selected_metric,
        text="Model",
        color="Category",  # one colour per category from categorize_model
        title=f"{selected_metric} vs Model Size",
    )
    # texttemplate forces the full model name to be drawn next to each marker.
    fig.update_traces(
        marker=dict(size=10),
        mode='markers+text',
        textposition="middle right",
        textfont=dict(size=10),
        texttemplate='%{text}',
    )

    # max() of an empty column is NaN, which would produce an invalid axis
    # range when no model reports this metric; fall back to 0 in that case.
    max_x_value = 0 if filtered_data.empty else filtered_data["Param (B)"].max()
    fig.update_layout(
        # Extend the x-axis past the largest model so its label is not clipped.
        xaxis_range=[0, max_x_value + 5],
        xaxis_title="Model Size (B)",
        yaxis_title=selected_metric,
        showlegend=False,
        height=800,
        margin=dict(t=50, l=50, r=100, b=50),  # wider right margin for labels
    )
    return fig


# Gradio Blocks Interface
def create_interface():
    """Assemble the Gradio UI: plot on the left, metric dropdown on the right.

    Returns:
        The ``gr.Blocks`` app; call ``.launch()`` to serve it.
    """
    with gr.Blocks() as interface:
        with gr.Row():
            # Plot column: relative width 4 within the row.
            with gr.Column(scale=4):
                plot = gr.Plot(value=plot_metric(DEFAULT_METRIC), label="OCR Benchmark Metrics")
            # Dropdown column: relative width 1 within the row.
            with gr.Column(scale=1):
                # Positional slice: skips the first five columns and the last
                # column, which is the derived "Category".
                # NOTE(review): assumes the first five CSV columns are
                # non-metric fields (e.g. Model, Param (B)) — confirm if the
                # CSV layout changes.
                metrics = list(data_valid.columns[5:-1])
                dropdown = gr.Dropdown(metrics, label="Select Metric", value=DEFAULT_METRIC)
        # Update the plot when dropdown selection changes
        dropdown.change(fn=plot_metric, inputs=dropdown, outputs=plot)
    return interface


# Launch the interface
if __name__ == "__main__":
    create_interface().launch()