MixoMax commited on
Commit
fd40de5
1 Parent(s): b2684c3

Upload mtgpt2.py

Browse files
Files changed (1) hide show
  1. mtgpt2.py +176 -0
mtgpt2.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """MtGPT2
3
+
4
+ Automatically generated by Colaboratory.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1HMq9Cp_jhqc9HlUipLXi8SC1mHQhSgGn
8
+ """
9
+
10
+ #@title Setup + Download Model
11
+ # Install and load dependencies
12
+ import locale
13
+ locale.getpreferredencoding = lambda: "UTF-8"
14
+
15
+ !pip install aitextgen
16
+ !pip install pytorch-lightning==1.7.0
17
+
18
+ import sys
19
+ from jinja2 import Template
20
+ from aitextgen import aitextgen
21
+
22
+ try:
23
+ from google.colab import files
24
+ except ImportError:
25
+ pass
26
+
27
+ # Download and load the model. Set to_gpu=False if running on a CPU.
28
+ ai = aitextgen(model="minimaxir/magic-the-gathering", to_gpu=True)
29
+
30
+ # This template is similar to Scryfall card formatting
31
+ TEMPLATE = Template(
32
+ """{{ c.name }}{% if c.manaCost %} {{ c.manaCost }}{% endif %}
33
+ {{ c.type }}
34
+ {{ c.text }}{% if c.power %}
35
+ {{ c.power }}/{{ c.toughness }}{% endif %}{% if c.loyalty %}
36
+ Loyalty: {{ c.loyalty }}{% endif %}"""
37
+ )
38
+
39
+ def render_card(card_dict):
40
+ card = TEMPLATE.render(c=card_dict)
41
+ if card_dict["name"]:
42
+ card = card.replace("~", card_dict["name"])
43
+ return card
44
+
45
+ prompt="\u003C|type|>Creature - Human \u003C|name|>Rezo the Destroyer of Politics \u003C|manaCost|> 2 {B/G}" #@param {type:"string"}
46
+ temperature = 0.7 #@param {type:"slider", min:0.1, max:1.2, step:0.1}
47
+ to_file = False #@param {type:"boolean"}
48
+
49
+ n = 100 if to_file else 8
50
+
51
+ cards = ai.generate(n=n,
52
+ schema=True,
53
+ prompt=prompt,
54
+ temperature=temperature,
55
+ return_as_list=True)
56
+
57
+ cards = list(map(render_card, cards))
58
+
59
+ if to_file:
60
+ file_path = "cards.txt"
61
+ with open(file_path, "w", encoding="utf-8") as f:
62
+ for card in cards:
63
+ f.write("{}\n{}".format(card, "=" * 20 + "\n"))
64
+ if "google.colab" in sys.modules:
65
+ files.download(file_path)
66
+ else:
67
+ print(("\n" + "=" * 20 + "\n").join(cards))
68
+
69
+ def generate_cards(
70
+ n_cards: int = 8,
71
+ temperature: float = 0.75,
72
+ name: str = "",
73
+ manaCost: str = "",
74
+ type: str = "",
75
+ text: str = "",
76
+ power: str = "",
77
+ toughness: str = "",
78
+ loyalty: str = ""
79
+ ):
80
+ #ensure n_cards is never 0 or negative
81
+ n_cards = int(n_cards)
82
+ if n_cards < 1:
83
+ n_cards = 1
84
+
85
+
86
+ #change manaCost from Format:
87
+ # 2UG
88
+ #to:
89
+ #{2}{U}{G}
90
+
91
+ manaCost_str = ""
92
+
93
+ for char in manaCost:
94
+ manaCost_str += "{"
95
+ manaCost_str += char
96
+ manaCost_str += "}"
97
+
98
+
99
+
100
+
101
+ prompt_str = ""
102
+
103
+ token_dict = {
104
+ "<|name|>": name,
105
+ "<|manaCost|>": manaCost_str,
106
+ "<|type|>": type,
107
+ "<|text|>": text,
108
+ "<|power|>": power,
109
+ "<|toughness|>": toughness,
110
+ "<|loyalty|>": loyalty
111
+ }
112
+
113
+ # Convert the token_dict into a formatted prompt string
114
+ for token, value in token_dict.items():
115
+ if value:
116
+ prompt_str += f"{token}{value}"
117
+
118
+ # Generate the cards using the prompt string and other parameters
119
+ cards = ai.generate(
120
+ n=n_cards,
121
+ schema=True,
122
+ prompt=prompt_str,
123
+ temperature=temperature,
124
+ return_as_list=True
125
+ )
126
+
127
+ cards = list(map(render_card, cards))
128
+
129
+ out_str = "\n=====\n".join(cards)
130
+
131
+ replacements = {
132
+ "{G}": "🌲",
133
+ "{U}": "🌊",
134
+ "{R}": "🔥",
135
+ "{B}": "💀",
136
+ "{W}": "☀️",
137
+ "{T}": "↩️",
138
+ #"{1}": "1⃣️",
139
+ #"{2}": "2⃣️",
140
+ #"{3}": "3⃣️",
141
+ #"{4}": "4⃣️",
142
+ #"{5}": "5⃣️",
143
+ #"{6}": "6⃣️",
144
+ #"{7}": "7⃣️",
145
+ #"{8}": "8⃣️",
146
+ #"{9}": "9⃣️",
147
+ }
148
+
149
+ for key, value in replacements.items():
150
+ out_str = out_str.replace(key, value)
151
+
152
+
153
+ return out_str
154
+
155
+ # Commented out IPython magic to ensure Python compatibility.
156
+ # %pip install gradio
157
+ import gradio as gr
158
+
159
+ iface = gr.Interface(
160
+ fn = generate_cards,
161
+ inputs=[
162
+ gr.Slider(minimum = 2, maximum=16, step=1, value=8),
163
+ gr.Slider(minimum = 0.1, maximum=1.5, step=0.01, value=0.75),
164
+ gr.Textbox(),
165
+ gr.Textbox(),
166
+ gr.Textbox(),
167
+ gr.Textbox(),
168
+ gr.Textbox(),
169
+ gr.Textbox(),
170
+ gr.Textbox(),
171
+ ],
172
+ outputs=gr.Textbox(),
173
+ title = "GPT-2 Powered MTG Card Generator",
174
+ description = "Enter Manacost as '2UG' for 2 colorless + Blue + Green mana. \n\n Temperature is recomended between 0.4 and 0.9. Anything above 1 will lead to random Chaos and very low values will just be boring."
175
+ )
176
+ iface.launch()