File size: 2,800 Bytes
f36568f
3b7c2cc
 
 
 
 
 
 
 
 
 
 
b5db008
 
3b7c2cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fed5103
c597797
3b7c2cc
32d52a2
3b7c2cc
 
 
 
9724cad
3b7c2cc
 
ca444f4
3b7c2cc
ca444f4
3b7c2cc
 
 
 
9724cad
 
3b7c2cc
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# NOTE: for some reason, the status of this demo shows as 'undefined'.
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
import re

# Hugging Face Hub identifiers. The tokenizer is loaded from the generic
# "bert-base-cased" checkpoint, while the classifier weights come from the
# fine-tuned repo. NOTE(review): presumably the fine-tuned model was trained
# with this same base tokenizer — confirm.
TOKENIZER_ID = "bert-base-cased"
MODEL_ID = "Qilex/colorpAI-monocolor"

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)

def round_to_2(num):
    """Return ``num`` rounded to two decimal places."""
    return round(num, 2)


# Single-letter classifier labels -> display names, in the order the
# resulting dict is built (White, Blue, Black, Red, Green, Colorless).
_COLOR_NAMES = {
    'W': 'White',
    'U': 'Blue',
    'B': 'Black',
    'R': 'Red',
    'G': 'Green',
    'C': 'Colorless',
}


def format_output(out_list):
    """Convert raw classifier scores into a ``{color name: score}`` dict.

    Args:
        out_list: Scores for one card, as a flat list of
            ``{"label": <letter>, "score": <float>}`` dicts, or that list
            wrapped in a one-element outer list.

    Returns:
        Dict with keys White, Blue, Black, Red, Green, Colorless (in that
        order), each mapped to its score rounded to two decimals.

    Raises:
        ValueError: if any of the labels W, U, B, R, G, C has no score.
    """
    # Unwrap a one-element outer list ([[{...}, ...]]), but only when the
    # single element really is a list, so a flat one-dict list stays a list.
    if len(out_list) == 1 and isinstance(out_list[0], list):
        out_list = out_list[0]

    # One pass over the results. If a label repeats, the last one wins,
    # as it did with the previous per-label loops.
    scores = {entry["label"]: entry["score"] for entry in out_list}

    missing = [label for label in _COLOR_NAMES if label not in scores]
    if missing:
        raise ValueError(f"classifier output is missing labels: {missing}")

    return {name: round_to_2(scores[label]) for label, name in _COLOR_NAMES.items()}

def predict(card):
    """Run the module-level text-classification pipeline on one card text.

    ``predictor_lg`` is assigned later in this module; it is resolved when
    this function is called, so it only needs to exist by the first call.
    """
    scores = predictor_lg(card)
    return scores

# A mana symbol made only of the letters W/U/B/R/G/C, optionally split by
# slashes, e.g. {W}, {GG}, {U/B}, {C}. The previous class "[W,U,B,R,G,C]"
# also matched a literal comma, so "{,}" was wrongly masked.
# NOTE(review): symbols such as {2/W} or {W/P} are not matched; widening this
# would change model inputs relative to whatever preprocessing the model was
# trained with — confirm before extending.
_COLORED_PIP_PATTERN = re.compile(r'\{[WUBRGC]+/*[WUBRGC]*\}')


def remove_colored_pips(text):
    """Replace colored/colorless mana symbols in ``text`` with ``{?}``.

    Masking the symbols keeps the explicit mana letters from revealing the
    card's color to the classifier. Other braced symbols ({2}, {T}, {X}, ...)
    are left unchanged.
    """
    return _COLORED_PIP_PATTERN.sub('{?}', text)

def preprocess_text(text):
    """Prepare raw card text for the classifier by masking mana pips."""
    cleaned = remove_colored_pips(text)
    return cleaned

def categorize(Card):
    """Gradio callback: classify one card's text into per-color scores.

    Args:
        Card: Card text as entered in the Textbox.

    Returns:
        The dict produced by ``format_output`` (color name -> rounded score),
        which the ``gr.Label`` output displays.
    """
    raw_prediction = predict(preprocess_text(Card))
    # Debug aid: echo the unformatted pipeline output to stdout.
    print(raw_prediction)
    return format_output(raw_prediction)
    
# UI copy passed to gr.Interface (title=, description=, article=) further down.
title = "Color pAI Version 1.0"
description = """
Color pAI is trained on around 18,000 Magic: the Gathering cards.
<br>
Input a card text using Scryfall syntax, and the model will evaluate which color it is most likely to be.
<br>Replace any card names with the word CARDNAME, and mana symbols with the uppercase letter enclosed in curly brackets, e.g. {U}.
<br>
<br>This only works on monocolored and colorless cards.
<br>
"""
# target="_blank" (not "blank") is the HTML keyword for opening a new tab.
article = '''
<br>
Magic: the Gathering is property of Wizards of the Coast. This project is made possible under their
<a href="https://company.wizards.com/en/legal/fancontentpolicy" target="_blank">fan content policy</a>.
'''
# Pipeline used by predict(). Softmax turns logits into probabilities.
# top_k=6 requests six label scores — presumably all of W/U/B/R/G/C, which
# format_output expects (TODO confirm against the model's label config).
predictor_lg = TextClassificationPipeline(
    model=model,
    tokenizer=tokenizer,
    function_to_apply='softmax',
    top_k=6,
)

demo = gr.Interface(
    fn=categorize,
    inputs=gr.Textbox(lines=1, placeholder="Type card text here."),
    outputs=gr.Label(num_top_classes=6),
    title=title,
    description=description,
    article=article,
)
demo.launch()