MixoMax commited on
Commit
eddf9fc
1 Parent(s): 8a93729

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gradio as gr
4
+
5
+ os.system("pip install setfit")
6
+
7
+ from setfit import SetFitModel
8
+
9
+ default_hf_home = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
10
+ HF_HOME = os.environ.get("HF_HOME", default_hf_home)
11
+
12
+ coloridentity_model = "joshuasundance/mtg-coloridentity-multilabel-classification"
13
+
14
+ labels = ["black", "green", "red", "blue", "white"]
15
+
16
+ model = SetFitModel.from_pretrained(coloridentity_model, cache_dir=HF_HOME)
17
+
18
+
19
+ def get_preds(input_text: str) -> tuple[str, dict[str, float]]:
20
+ preds = model.predict_proba(input_text)
21
+ pred_dict = {label: preds[i] for i, label in enumerate(labels)}
22
+
23
+ color_identity = "/".join([color for i, color in enumerate(labels) if preds[i] > 0.5])
24
+
25
+ if color_identity == "":
26
+ color_identity = "colorless"
27
+
28
+ return color_identity, pred_dict
29
+
30
+ iface = gr.Interface(
31
+ fn=get_preds,
32
+ inputs=gr.Textbox(),
33
+ outputs=[
34
+ gr.Textbox(),
35
+ gr.Label(),
36
+ ],
37
+ title="Magic the Gathering Color Identity Classifier",
38
+ description="Enter card name and ability text to classify the color identity of the card.",
39
+ allow_flagging=False,
40
+ )
41
+
42
+ iface.launch(show_api=True)