import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
import re
# Load the tokenizer and the sequence-classification model once, at import time,
# via Hugging Face `from_pretrained`.
# NOTE(review): the tokenizer comes from "bert-base-cased" while the weights come
# from "Qilex/colorpAI-monocolor" — presumably the classifier was fine-tuned from
# bert-base-cased; confirm the vocabularies match.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = AutoModelForSequenceClassification.from_pretrained("Qilex/colorpAI-monocolor")
def round_to_2(num):
    """Return *num* rounded to two decimal places using the built-in ``round``."""
    places = 2
    return round(num, places)
def format_output(out_list):
    """Convert classifier output into a ``{color name: score}`` dict.

    Args:
        out_list: iterable of dicts carrying ``"label"`` and ``"score"`` keys
            (presumably the list returned by the text-classification
            pipeline). Labels of interest are the color codes W, U, B, R, G
            and C; any other label is ignored.

    Returns:
        A dict with keys ``White``, ``Blue``, ``Black``, ``Red``, ``Green``
        and ``Colorless`` (in that order). Each value is the matching score
        rounded to two decimals, or ``0`` when that label does not appear in
        *out_list*.
    """
    # Model label -> display name, listed in the order the output dict uses.
    code_to_name = {
        "W": "White",
        "U": "Blue",
        "B": "Black",
        "R": "Red",
        "G": "Green",
        "C": "Colorless",
    }

    scores = {}
    for entry in out_list:
        label = entry["label"]
        # Only read "score" for labels we report on. If a label repeats, its
        # last score wins.
        if label in code_to_name:
            scores[label] = entry["score"]

    # Every missing label defaults to 0, so a partial prediction list can no
    # longer leave a color unassigned (which used to raise UnboundLocalError).
    return {
        name: round_to_2(scores[code]) if code in scores else 0
        for code, name in code_to_name.items()
    }
def predict(card):
    """Run *card* through the module-level ``predictor_lg`` pipeline.

    The raw pipeline output is returned unchanged.
    """
    raw_prediction = predictor_lg(card)
    return raw_prediction
# Colored mana symbols in Scryfall syntax:
#   single {W}, hybrid {W/U}, monocolored hybrid {2/W},
#   Phyrexian {W/P} and hybrid Phyrexian {G/U/P}.
# {C} is included, as in the original pattern. The old character class also
# contained literal commas ([W,U,B,R,G,C]), so "{,}" was wrongly matched.
_COLORED_PIP_RE = re.compile(r"\{(?:2/)?[WUBRGC]+/*[WUBRGC]*(?:/P)?\}")


def remove_colored_pips(text):
    """Replace every colored mana symbol in *text* with ``{?}``.

    This hides mana-cost colors from the classifier. Generic and utility
    symbols such as ``{1}``, ``{X}`` and ``{T}`` are left untouched.

    NOTE(review): Phyrexian and ``{2/W}`` pips are now masked too. If the
    model was trained on text preprocessed with the narrower old pattern,
    confirm this does not shift predictions.
    """
    return _COLORED_PIP_RE.sub("{?}", text)
def preprocess_text(text):
    """Prepare raw card text for the classifier.

    The only preprocessing step at present is masking colored mana pips.
    """
    masked = remove_colored_pips(text)
    return masked
def categorize(card):
    """Classify one card's text and return per-color scores.

    Steps: mask colored pips, run the classifier, then reshape the result
    into the ``{color name: score}`` dict handed to the Gradio ``Label``
    output.
    """
    return format_output(predict(preprocess_text(card)))
# Strings passed to gr.Interface as its title, description and article.
title = "Color pAI Version 1.0"
description = """
Color pAI is trained on around 18,000 Magic: the Gathering cards made available under Wizards of the Coast's
fan content policy.
Input a card's text using Scryfall syntax, and the model will evaluate which color it is most likely to be.
Replace any card names with the word CARDNAME.
This only works on monocolored cards. Version 2 will also handle multicolored cards.
"""
article = '''
Magic: the Gathering is property of Wizards of the Coast.
'''
# Classification pipeline over the tokenizer/model loaded above. Softmax turns
# logits into probabilities; top_k=6 presumably returns one score per color
# label (W, U, B, R, G, C) that format_output reads.
predictor_lg = TextClassificationPipeline(
    model=model,
    tokenizer=tokenizer,
    function_to_apply="softmax",
    top_k=6,
)

# Build the web UI and start serving it immediately (runs at import time).
demo = gr.Interface(
    fn=categorize,
    inputs=gr.Textbox(lines=1, placeholder="Type card text here."),
    outputs=gr.Label(num_top_classes=6),
    title=title,
    description=description,
    article=article,
)
demo.launch()