Spaces:
Runtime error
Runtime error
File size: 1,457 Bytes
6db7b4c 028951c f94be16 028951c 6db7b4c 028951c 6db7b4c 028951c 6db7b4c 028951c 6db7b4c 028951c 6db7b4c 028951c f94be16 028951c 6db7b4c 028951c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import gradio
import pandas as pd
from matplotlib import pyplot as plt
from config import CONFIG
from data import get_extra_tokens, BenetechOutput, ChartType
from model import predict_string, build_model
def gradio_visualize_prediction(string, output_path="plot.png"):
    """Render a model prediction string as a chart image.

    The benetech prompt prefix (``get_extra_tokens().benetech_prompt``) is
    stripped first. If the remainder does not match the pattern that
    ``BenetechOutput`` expects, nothing is drawn and ``None`` is returned.

    Args:
        string: Raw text produced by the model.
        output_path: File the rendered chart is saved to. Defaults to
            ``"plot.png"``, the location previously hard-coded here.

    Returns:
        ``output_path`` after the chart has been written, or ``None`` when
        ``string`` does not match the expected pattern.

    Raises:
        KeyError: if the parsed chart type has no plotting method below.
    """
    string = string.removeprefix(get_extra_tokens().benetech_prompt)
    if not BenetechOutput.does_string_match_expected_pattern(string):
        return None
    benetech_output = BenetechOutput.from_string(string)

    # The two series may differ in length; keep only as many points as the
    # shorter one has so x and y pair up one-to-one.
    n_points = min(len(benetech_output.x_data), len(benetech_output.y_data))
    x = benetech_output.x_data[:n_points]
    y = benetech_output.y_data[:n_points]

    # Use a dedicated figure per call: drawing via the implicit pyplot
    # "current figure" would overlay each new chart on earlier ones, and
    # never closing figures keeps every one of them alive in memory.
    fig, ax = plt.subplots()
    try:
        plot_fn = {
            ChartType.line: ax.plot,
            ChartType.scatter: ax.scatter,
            # NOTE(review): barh takes (bar positions, widths); confirm that
            # x_data/y_data are ordered that way for horizontal bars.
            ChartType.horizontal_bar: ax.barh,
            ChartType.vertical_bar: ax.bar,
            ChartType.dot: ax.scatter,
        }[benetech_output.chart_type]
        plot_fn(x, y)
        # Tilt x tick labels so long category names do not overlap.
        ax.tick_params(axis="x", labelrotation=30)
        fig.savefig(output_path)
    finally:
        plt.close(fig)
    return output_path
def main():
    """Build the model and serve it through a Gradio image-to-text interface.

    The model is built from CONFIG with ``pretrained_model_name`` set to
    ``"checkpoint"``. Each uploaded image is passed to ``predict_string``
    together with that model, and the resulting text is shown.
    """
    # NOTE(review): ``cfg`` is the shared CONFIG object, not a copy, so this
    # assignment is visible to every other user of CONFIG. Presumably that
    # is intended (e.g. so prediction code sees the same checkpoint); confirm.
    cfg = CONFIG
    cfg.pretrained_model_name = "checkpoint"
    model = build_model(cfg)

    interface_options = dict(
        title="Making graphs accessible",
        description=(
            "Generate textual representation of a graph\n"
            "https://www.kaggle.com/competitions/benetech-making-graphs-accessible"
        ),
        fn=lambda image: predict_string(image, model),
        inputs="image",
        outputs="text",
        examples="examples",
    )
    gradio.Interface(**interface_options).launch()
|