File size: 2,640 Bytes
3b7c2cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
import re

# Hugging Face Hub checkpoints loaded at import time (downloads on first run).
# NOTE(review): the tokenizer comes from the base "bert-base-cased" checkpoint
# while the weights come from the fine-tuned repo; presumably the fine-tune kept
# the base vocabulary — confirm.
_TOKENIZER_CHECKPOINT = "bert-base-cased"
_MODEL_CHECKPOINT = "Qilex/colorpAI-monocolor"

tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(_MODEL_CHECKPOINT)

def round_to_2(num):
    """Return *num* rounded to two decimal places using built-in ``round``."""
    places = 2
    rounded = round(num, places)
    return rounded

def format_output(out_list):
    """Turn raw classifier output into a ``{color name: score}`` dict.

    Args:
        out_list: Pipeline output, a list of ``{"label": ..., "score": ...}``
            dicts. One extra level of list nesting (``[[{...}, ...]]``) is
            unwrapped first. ``None`` or an empty list means "no scores".

    Returns:
        A dict with keys ``White``, ``Blue``, ``Black``, ``Red``, ``Green``,
        ``Colorless`` (in that order). Each value is the score of label ``W``,
        ``U``, ``B``, ``R``, ``G`` or ``C`` respectively, rounded to two decimal
        places, or ``0`` when that label is absent. If a label occurs more than
        once, its last occurrence wins.
    """
    # (model label, display name); tuple order fixes the key order of the result.
    color_names = (
        ("W", "White"),
        ("U", "Blue"),
        ("B", "Black"),
        ("R", "Red"),
        ("G", "Green"),
        ("C", "Colorless"),
    )
    entries = out_list or []
    # NOTE(review): depending on the transformers version, a pipeline called on
    # a single string with top_k may return [[{...}, ...]]; accept both shapes.
    if entries and isinstance(entries[0], list):
        entries = entries[0]
    # Single pass; later duplicates overwrite earlier ones, as before.
    scores = {entry["label"]: entry["score"] for entry in entries}
    # Every missing label defaults to 0. Previously only "W" did and the other
    # colors raised UnboundLocalError.
    return {name: round(scores.get(label, 0), 2) for label, name in color_names}

def predict(card):
    """Run the module-level ``predictor_lg`` pipeline on *card* and return its output.

    ``predictor_lg`` is assigned later in the module, so this is only usable
    once the module has finished loading.
    """
    raw_scores = predictor_lg(card)
    return raw_scores

# Braced symbols made only of color letters, optionally split by slashes,
# e.g. {G}, {C}, {WU}, {W/U}. Compiled once at import.
_COLORED_PIP_RE = re.compile(r'\{[WUBRGC]+/*[WUBRGC]*\}')

def remove_colored_pips(text):
    """Replace colored mana symbols in *text* with the placeholder ``{?}``.

    A symbol is masked when its braces contain one or more of ``W U B R G C``,
    then any number of ``/``, then zero or more of those letters. Other braced
    symbols such as ``{2}`` or ``{T}`` are left unchanged.

    The previous pattern used the class ``[W,U,B,R,G,C]``. Inside a character
    class the commas are literal, so ``{,}`` and ``{W,U}`` were masked too.
    Commas are now excluded.

    NOTE(review): ``{W/P}`` and ``{2/W}`` contain characters outside the class,
    so they are not masked. Widening the pattern changes the model's inputs,
    so confirm against the training preprocessing first.
    """
    return _COLORED_PIP_RE.sub('{?}', text)

def preprocess_text(text):
    """Prepare raw card text for the classifier.

    The only step is masking colored mana symbols via ``remove_colored_pips``.
    """
    masked = remove_colored_pips(text)
    return masked

def categorize(card):
    """Gradio callback: classify *card* text and return per-color scores.

    The text is preprocessed, run through ``predict``, and the raw pipeline
    output is reshaped by ``format_output`` into the dict shown by the
    ``gr.Label`` output component.
    """
    return format_output(predict(preprocess_text(card)))
    
# User-facing text passed to gr.Interface as title/description/article.
# description and article contain inline HTML.
title = "Color pAI Version 1.0"
description = """
Color pAI is trained on around 18,000 Magic: The Gathering cards made available under Wizards of the Coast's
<a href="https://company.wizards.com/en/legal/fancontentpolicy" target="_blank" rel="noopener noreferrer">fan content policy</a>.
<br>
Input a card's text using Scryfall syntax, and the model will evaluate which color it is most likely to be.
<br>Replace any card names with the word CARDNAME.
<br>
<br>This only works on monocolored cards. Version 2 will also handle multicolored cards.
<br>
"""
article = '''
<br>
Magic: The Gathering is property of Wizards of the Coast.
'''
# Classifier returning scores for all six color classes (softmax-normalised),
# consumed by predict() and reshaped by format_output().
predictor_lg = TextClassificationPipeline(
    model=model,
    tokenizer=tokenizer,
    function_to_apply='softmax',
    top_k=6,
)

# Web UI: one text box in, a label component showing all six scores out.
# Launched at import time, as before.
demo = gr.Interface(
    fn=categorize,
    inputs=gr.Textbox(lines=1, placeholder="Type card text here."),
    outputs=gr.Label(num_top_classes=6),
    title=title,
    description=description,
    article=article,
)
demo.launch()