"""Gradio demo that predicts the color of a monocolored Magic: the Gathering card.

Colored mana symbols are masked out of the card text, and the text is
classified by the ``Qilex/colorpAI-monocolor`` sequence-classification model.
"""

import re

import gradio as gr
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TextClassificationPipeline,
)

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = AutoModelForSequenceClassification.from_pretrained("Qilex/colorpAI-monocolor")

# Softmax over the model's logits; top_k=6 requests the score of every class.
predictor_lg = TextClassificationPipeline(
    model=model, tokenizer=tokenizer, function_to_apply="softmax", top_k=6
)

# Model label -> display name. Insertion order fixes the order of the
# dict returned by format_output.
COLOR_NAMES = {
    "W": "White",
    "U": "Blue",
    "B": "Black",
    "R": "Red",
    "G": "Green",
    "C": "Colorless",
}

# A mana symbol built only from color letters, optionally split by "/",
# e.g. {W}, {U/B}, {WU}, {C}.
# NOTE(review): hybrid-generic symbols like {2/W} and Phyrexian symbols like
# {W/P} are NOT matched and still expose the color; extending this would
# change model inputs, so confirm against the training preprocessing first.
COLORED_PIP_RE = re.compile(r"\{[WUBRGC]+/*[WUBRGC]*\}")


def round_to_2(num):
    """Return ``num`` rounded to two decimal places."""
    return round(num, 2)


def format_output(out_list):
    """Convert pipeline predictions into a ``{color name: score}`` dict for ``gr.Label``.

    Args:
        out_list: Predictions from ``predictor_lg``: a list of dicts with
            ``"label"`` (W/U/B/R/G/C) and ``"score"`` keys. A list that wraps
            such a list is also accepted.

    Returns:
        Dict mapping White, Blue, Black, Red, Green, Colorless (in that order)
        to the score rounded to two decimals. A color absent from the
        predictions gets 0 instead of raising.
    """
    # Depending on the transformers version, a single-string call with top_k
    # set at construction may come back wrapped in an extra list.
    if out_list and isinstance(out_list[0], list):
        out_list = out_list[0]
    # Later entries win on duplicate labels, as in a sequential scan.
    scores = {entry["label"]: entry["score"] for entry in out_list}
    return {
        name: round_to_2(scores.get(label, 0))
        for label, name in COLOR_NAMES.items()
    }


def predict(card):
    """Run the classifier on ``card`` and return the raw pipeline output."""
    return predictor_lg(card)


def remove_colored_pips(text):
    """Replace every colored mana symbol in ``text`` with ``{?}``."""
    return COLORED_PIP_RE.sub("{?}", text)


def preprocess_text(text):
    """Prepare raw card text for the classifier (currently: mask colored pips)."""
    return remove_colored_pips(text)


def categorize(card):
    """Gradio callback: preprocess ``card``, classify it, and format the scores."""
    text = preprocess_text(card)
    prediction = predict(text)
    return format_output(prediction)


title = "Color pAI Version 1.0"
description = """ Color pAI is trained on around 18,000 Magic: the Gathering cards made available under Wizards of the Coast's fan content policy.
Input card text using Scryfall syntax, and the model will evaluate which color it is most likely to be.
Replace any card names with the word CARDNAME.

This only works on monocolored cards. Version 2 will also handle multicolored cards.
"""
article = '''
Magic: the Gathering is property of Wizards of the Coast. '''

demo = gr.Interface(
    fn=categorize,
    inputs=gr.Textbox(lines=1, placeholder="Type card text here."),
    outputs=gr.Label(num_top_classes=6),
    title=title,
    description=description,
    article=article,
)

if __name__ == "__main__":
    demo.launch()