Spaces:
Runtime error
Runtime error
File size: 1,161 Bytes
eddf9fc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import importlib
import os
import subprocess
import sys

import gradio as gr

try:
    from setfit import SetFitModel
except ImportError:
    # setfit is not installed in this environment: install it with the pip
    # that belongs to the running interpreter. An argument list (no shell
    # string) is used and check=True makes a failed install abort start-up
    # with a clear error instead of a confusing ImportError afterwards.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "setfit"],
        check=True,
    )
    # Let the import system notice the freshly installed package.
    importlib.invalidate_caches()
    from setfit import SetFitModel
# Hub repository id of the classifier loaded below.
coloridentity_model = "joshuasundance/mtg-coloridentity-multilabel-classification"

# Colour names; get_preds pairs index i of the model output with labels[i].
# NOTE(review): this order is assumed to match the model's output order - confirm.
labels = [
    "black",
    "green",
    "red",
    "blue",
    "white",
]

# Cache location: the HF_HOME environment variable when set, otherwise
# ~/.cache/huggingface under the current user's home directory.
default_hf_home = os.path.join(
    os.path.expanduser("~"),
    ".cache",
    "huggingface",
)
HF_HOME = os.getenv("HF_HOME", default_hf_home)

# Loaded once at import time and reused for every prediction request.
model = SetFitModel.from_pretrained(
    coloridentity_model,
    cache_dir=HF_HOME,
)
def get_preds(input_text: str) -> tuple[str, dict[str, float]]:
    """Classify the color identity of a card from its text.

    Args:
        input_text: Text passed unchanged to ``model.predict_proba``.

    Returns:
        A ``(color_identity, pred_dict)`` pair. ``color_identity`` is the
        "/"-joined names of every label whose score is above 0.5, or
        ``"colorless"`` when no label clears that threshold. ``pred_dict``
        maps each name in ``labels`` to its score as a builtin ``float``.
    """
    preds = model.predict_proba(input_text)
    # Convert every score to a plain float so the returned dict really is
    # dict[str, float] (the model output is presumably a tensor/array whose
    # elements are not builtin floats - TODO confirm). Indexing by position
    # keeps a loud IndexError if the model yields fewer scores than labels.
    pred_dict = {label: float(preds[i]) for i, label in enumerate(labels)}
    color_identity = "/".join(
        label for label, score in pred_dict.items() if score > 0.5
    )
    if not color_identity:
        color_identity = "colorless"
    return color_identity, pred_dict
# Web UI: one text input, two outputs matching the tuple returned by
# get_preds (a text box for the joined color identity and a label widget
# for the per-color scores).
iface = gr.Interface(
    fn=get_preds,
    inputs=gr.Textbox(),
    outputs=[
        gr.Textbox(),
        gr.Label(),
    ],
    title="Magic the Gathering Color Identity Classifier",
    description="Enter card name and ability text to classify the color identity of the card.",
    # Gradio documents string values for this option; "never" is the
    # documented way to disable flagging (a bare boolean is not).
    allow_flagging="never",
)

# Start the app and keep the API documentation page enabled.
iface.launch(show_api=True)