# mariobedrock / app.py
# Banjo Obayomi
# remove opus
# b2d4ee3
import json
import uuid
import boto3
import gradio as gr
import uvicorn
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from mario_gpt.lm import MarioLM
from mario_gpt.utils import add_prompt_to_image, convert_level_to_png
# import anthropic
# opus_client = anthropic.Anthropic()
# Shared Bedrock runtime client used by every model-invocation helper below.
# Region is pinned to us-east-1, where all referenced model IDs are available.
bedrock_runtime = boto3.client(
    service_name="bedrock-runtime",
    region_name="us-east-1",
)
def get_raw_text(level_data):
    """Join level rows into a single newline-terminated string.

    Args:
        level_data: iterable of row strings.

    Returns:
        All rows concatenated, each followed by "\\n"; "" for empty input.
    """
    # "".join is linear; the original repeated += concatenation is quadratic.
    return "".join(f"{line}\n" for line in level_data)
def combine_levels(level_arrays):
    """Concatenate several level sections side by side.

    Args:
        level_arrays: list of levels, each a list of row strings. Every
            level must have at least as many rows as the first one
            (rows are matched by index).

    Returns:
        A single level as a list of row strings, where row i is the
        concatenation of row i from every input level.
    """
    num_rows = len(level_arrays[0])
    # str.join per row avoids quadratic += string concatenation.
    return [
        "".join(level[row] for level in level_arrays)
        for row in range(num_rows)
    ]
def write_level_to_file(level_data, file_name):
    """Write level rows to *file_name*, one row per line.

    Args:
        level_data: iterable of row strings.
        file_name: destination path; overwritten if it exists.
    """
    # Explicit UTF-8 so output does not depend on the platform locale.
    with open(file_name, "w", encoding="utf-8") as file:
        file.writelines(f"{line}\n" for line in level_data)
def clean_level_data(input_string):
    """Extract and normalize level rows from raw LLM output.

    The model is asked to return a Python-style list of row strings.
    This pulls out the text between the first "[" and the last "]",
    splits on commas, strips quotes, and pads/truncates every row to
    exactly 50 characters ("-" is the sky tile used as padding).

    Args:
        input_string: raw model response containing a bracketed array.

    Returns:
        List of 50-character row strings.
    """
    # Locate the bracketed array inside any surrounding prose.
    start_index = input_string.find("[")
    end_index = input_string.rfind("]")
    level_data = input_string[start_index + 1 : end_index]

    cleaned_lines = []
    for line in level_data.split(","):
        # Models quote rows with either single or double quotes, so
        # strip both (the original handled only single quotes).
        cleaned_line = line.strip().strip("'\"")
        # Force every row to the expected 50-column width.
        cleaned_lines.append(cleaned_line.ljust(50, "-")[:50])
    return cleaned_lines
def call_llama3_70b(system_prompt, prompt, temperature):
    """Generate text with Llama 3 70B Instruct on Amazon Bedrock.

    Args:
        system_prompt: system instructions for the model.
        prompt: user message.
        temperature: sampling temperature (0.0-1.0).

    Returns:
        The model's generated text, stripped of surrounding whitespace.
    """
    # Llama 3 instruct format: exactly one <|begin_of_text|>, then
    # alternating header blocks. The original template repeated
    # <|begin_of_text|> before the user turn and omitted the trailing
    # empty assistant header that cues the model to answer.
    llama_prompt = (
        "<|begin_of_text|>"
        "<|start_header_id|>system<|end_header_id|>\n\n"
        f"{system_prompt}<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"{prompt}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    prompt_config = {
        "prompt": llama_prompt,
        "max_gen_len": 2048,
        "top_p": 0.9,
        "temperature": temperature,
    }
    response = bedrock_runtime.invoke_model(
        body=json.dumps(prompt_config),
        modelId="meta.llama3-70b-instruct-v1:0",
        accept="application/json",
        contentType="application/json",
    )
    response_body = json.loads(response.get("body").read())
    return response_body["generation"].strip()
def call_llama3_8b(system_prompt, prompt, temperature):
    """Generate text with Llama 3 8B Instruct on Amazon Bedrock.

    Args:
        system_prompt: system instructions for the model.
        prompt: user message.
        temperature: sampling temperature (0.0-1.0).

    Returns:
        The model's generated text, stripped of surrounding whitespace.
    """
    # Llama 3 instruct format: exactly one <|begin_of_text|>, then
    # alternating header blocks. The original template repeated
    # <|begin_of_text|> before the user turn and omitted the trailing
    # empty assistant header that cues the model to answer.
    llama_prompt = (
        "<|begin_of_text|>"
        "<|start_header_id|>system<|end_header_id|>\n\n"
        f"{system_prompt}<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"{prompt}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    prompt_config = {
        "prompt": llama_prompt,
        "max_gen_len": 2048,
        "top_p": 0.9,
        "temperature": temperature,
    }
    response = bedrock_runtime.invoke_model(
        body=json.dumps(prompt_config),
        modelId="meta.llama3-8b-instruct-v1:0",
        accept="application/json",
        contentType="application/json",
    )
    response_body = json.loads(response.get("body").read())
    return response_body["generation"].strip()
# def call_claude_3_opus(system_prompt, prompt, temperature):
# message = opus_client.messages.create(
# model="claude-3-opus-20240229",
# max_tokens=4096,
# system=system_prompt,
# temperature=temperature,
# messages=[{"role": "user", "content": prompt}],
# )
# return message.content[0].text
# def call_claude_3_opus(system_prompt, prompt, temperature):
# prompt_config = {
# "anthropic_version": "bedrock-2023-05-31",
# "max_tokens": 4096,
# "system": system_prompt,
# "messages": [
# {
# "role": "user",
# "content": [
# {"type": "text", "text": prompt},
# ],
# }
# ],
# "temperature": temperature,
# }
# body = json.dumps(prompt_config)
# modelId = "anthropic.claude-3-opus-20240229-v1:0"
# accept = "application/json"
# contentType = "application/json"
# response = bedrock_runtime.invoke_model(
# body=body, modelId=modelId, accept=accept, contentType=contentType
# )
# response_body = json.loads(response.get("body").read())
# results = response_body.get("content")[0].get("text")
# return results
# Call Claude model
def call_claude_3_sonnet(system_prompt, prompt, temperature):
    """Invoke Claude 3 Sonnet on Amazon Bedrock and return its reply text."""
    request = json.dumps(
        {
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": 4096,
            "system": system_prompt,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                    ],
                }
            ],
            "temperature": temperature,
        }
    )
    response = bedrock_runtime.invoke_model(
        body=request,
        modelId="anthropic.claude-3-sonnet-20240229-v1:0",
        accept="application/json",
        contentType="application/json",
    )
    payload = json.loads(response.get("body").read())
    # Messages API returns a list of content blocks; take the first text one.
    return payload.get("content")[0].get("text")
def call_claude_3_haiku(system_prompt, prompt, temperature):
    """Invoke Claude 3 Haiku on Amazon Bedrock and return its reply text."""
    request = json.dumps(
        {
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": 4096,
            "system": system_prompt,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                    ],
                }
            ],
            "temperature": temperature,
        }
    )
    response = bedrock_runtime.invoke_model(
        body=request,
        modelId="anthropic.claude-3-haiku-20240307-v1:0",
        accept="application/json",
        contentType="application/json",
    )
    payload = json.loads(response.get("body").read())
    # Messages API returns a list of content blocks; take the first text one.
    return payload.get("content")[0].get("text")
def call_command_r_plus(system_prompt, prompt, temperature):
    """Invoke Cohere Command R+ on Amazon Bedrock and return its reply text."""
    # NOTE(review): the system prompt is sent as a prior USER turn in
    # chat_history rather than as a preamble — confirm this is intended.
    request = json.dumps(
        {
            "message": prompt,
            "max_tokens": 4096,
            "chat_history": [{"role": "USER", "message": system_prompt}],
            "temperature": temperature,
        }
    )
    response = bedrock_runtime.invoke_model(
        body=request,
        modelId="cohere.command-r-plus-v1:0",
        accept="application/json",
        contentType="application/json",
    )
    payload = json.loads(response.get("body").read())
    return payload.get("text")
def call_command_r(system_prompt, prompt, temperature):
    """Invoke Cohere Command R on Amazon Bedrock and return its reply text."""
    # NOTE(review): the system prompt is sent as a prior USER turn in
    # chat_history rather than as a preamble — confirm this is intended.
    request = json.dumps(
        {
            "message": prompt,
            "max_tokens": 4096,
            "chat_history": [{"role": "USER", "message": system_prompt}],
            "temperature": temperature,
        }
    )
    response = bedrock_runtime.invoke_model(
        body=request,
        modelId="cohere.command-r-v1:0",
        accept="application/json",
        contentType="application/json",
    )
    payload = json.loads(response.get("body").read())
    return payload.get("text")
# Default system prompt: instructs the LLM to design a 14-row x 50-column
# tile grid and return it as a Python-style 2D array of row strings, which
# clean_level_data() then parses. Users can override it in the UI accordion.
system_prompt_text = """
As an esteemed level designer renowned for creating some of the top 100 levels in Super Mario Maker, you are tasked with crafting a playable section for the original Super Mario on NES. Your extensive experience and creativity are key to designing levels that are not only challenging but also immensely enjoyable. Use the following symbols to represent different game elements, ensuring each level is a masterpiece of design:
<symbols>
- = "Sky"
X = "Unbreakable Block"
E = "Enemy"
o = "Coin"
S = "Breakable Block"
? = "Question Block"
[] = "Pipe"
<> = "End of Pipe"
</symbols>
Adhere to these level layout specifications:
<level guidelines>
Pipes should be vertical and follow this format:
<>
[]
[]
Ensure there is a clear and navigable path that Mario can follow from the start to the end of the level. This path may involve jumping on blocks or pipes, running on blocks.
The path should be continuous and not lead Mario into any dead ends or impossible situations.
Place unbreakable blocks (X) or other platform elements strategically to create a solid foundation for Mario to walk on. Avoid creating large gaps or sections without any ground or platforms, as Mario needs a surface to stand on.
Adjust the complexity and elements based on the specific level request, ensuring that Mario can always complete the level successfully by following the designated path.
</level guidelines>
Here are some examples to follow:
<example>
<input>
Make a simple level that has no enemies, but lots and lots of coins and blocks
</input>
<output>
['-------------------------------------------------o', '-------------------------------------------------o', '------------------------------------------------oo', '-----------------------------------------------ooo', '----------------------------------------------oooo', '---------------------------------------------ooooo', '--------------------------------------------oooooo', '-------------------------------------------ooooooo', 'ooooooooooooooooooooooooooooooooooooooooooooooooo-', 'XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-', 'XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-', 'XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-', 'XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-', 'XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-']
</output>
</example>
<example>
<input>
Design a level with blocks arranged in a pyramid-like shape, with coins scattered around the base and Goombas guarding the top.
</input>
<output>
['-------------------------------------------------o', '-------------------------------------------------o', '--------------------------------------------------', '--------------------------------------------------', '--------------------------------------------------', '--------------------------------------------------', '----------------EEE-------------------------------', '--------------ooooooo-----------------------------',
'------------ooo?S?Sooo----------------------------', '-----------oooSSSSSoooo---------------------------', '----------oooSSSXSSSoooo--------------------------', '---------oooSSSXXXSSSoooo-------------------------', '--------oooSSSXXXXXSSSoooo------------------------', 'XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX']
</output>
</example>
<example>
<input>
Generate a level with a few pipes, many coins. make sure there are around 10 enemies.
</input>
<output>
['--------------------------------------------------', '--------------------------------------------------', '-------------------------o------ooo---------------', '-----------------------oooooooooooo---------------', '----------------------o-o-oooooooo-o--------------', '---------------------ooooo----E--ooo--------------', '-----------------E---o-------E----o---------------', '------------E----ooooo-----E-----ooo-----E--------', '-----------ooooooo-----E--ooooooo---E-------------', '-E--<>---oooo----E--<>--oo-----E-----<>----------E',
'-X--[]--X----E--XX-[]--X--E---XX----[]------E---XX',
'XX--[]--XXXXSSSXX--[]--XXSSSSSXX----[]----SSSSSXXX',
'XXX-[]--XXXXXXXXXX-[]--XXXXXXXXXX---[]---XXXXXXXXX', 'XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX']
</output>
</example>
Generate the level section as a 2D array, where each row is represented as a string of characters. The level section should be 14 rows tall and 50 columns wide. Only return the 2D array of characters.
Remember, your creations should challenge players but remain fair. Use your expertise to weave together obstacles and rewards, encouraging exploration and skillful play. Always ensure that Mario has a clear and navigable route to finish the level, and provide ample block tiles for Mario to walk on.
"""
# MarioLM is used here only for its tokenizer when rendering levels to PNG,
# so CPU inference is sufficient (the GPU lines below stay disabled).
mario_lm = MarioLM()
# device = torch.device('cuda')
# mario_lm = mario_lm.to(device)
# Tile sprite directory consumed by mario_gpt's rendering utilities.
TILE_DIR = "mario_gpt/data/tiles"
# FastAPI app; the Gradio UI is mounted onto it at the bottom of the file.
app = FastAPI()
def make_html_file(generated_level):
    """Write a playable HTML page for *generated_level* into static/.

    The page loads the CheerpJ JVM emulator, exposes the level text as
    /str/mylevel.txt, and runs the bundled mario.jar against it.

    Args:
        generated_level: level as a single newline-separated string.

    Returns:
        The generated file's basename ("demo-<uuid>.html").
    """
    level_text = generated_level
    # Unique name per request so concurrent users don't clobber each other.
    unique_id = uuid.uuid1()
    # NOTE(review): level_text is interpolated inside a JS template
    # literal — a backtick in the level text would break the page.
    # Level tiles shouldn't contain one, but this is unescaped; verify.
    with open(f"static/demo-{unique_id}.html", "w", encoding="utf-8") as f:
        f.write(
            f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>Mario Game</title>
<script src="https://cjrtnc.leaningtech.com/20230216/loader.js"></script>
</head>
<body>
</body>
<script>
cheerpjInit().then(function () {{
cheerpjAddStringFile("/str/mylevel.txt", `{level_text}`);
}});
cheerpjCreateDisplay(612, 600);
cheerpjRunJar("/app/static/mario.jar");
</script>
</html>"""
        )
    return f"demo-{unique_id}.html"
def generate(model, prompt, temperature, system_prompt=system_prompt_text):
    """Generate a playable Mario level from a text prompt.

    Args:
        model: display name of the Bedrock model to use.
        prompt: user's level description.
        temperature: sampling temperature forwarded to the model.
        system_prompt: system prompt; falls back to the default when empty.

    Returns:
        [prompt_image, gradio_html]: a PNG preview of the level and an
        HTML snippet embedding the playable level in an iframe.

    Raises:
        ValueError: if *model* is not a known model name.
    """
    print(f"Using prompt: {prompt}")
    if system_prompt == "":
        system_prompt = system_prompt_text

    # Dispatch table keeps the UI-name -> caller mapping in one place.
    model_callers = {
        "Claude Sonnet": call_claude_3_sonnet,
        "Claude Haiku": call_claude_3_haiku,
        "Llama3 70B": call_llama3_70b,
        "Llama3 8B": call_llama3_8b,
        "Command R Plus": call_command_r_plus,
        "Command R": call_command_r,
    }
    try:
        caller = model_callers[model]
    except KeyError:
        raise ValueError("Invalid model") from None
    level = caller(system_prompt, prompt, temperature)

    cleaned_level = clean_level_data(level)
    raw_level_text = get_raw_text(cleaned_level)
    filename = make_html_file(raw_level_text)
    img = convert_level_to_png(cleaned_level, mario_lm.tokenizer)[0]
    prompt_image = add_prompt_to_image(img, prompt)
    # Bug fix: embed the freshly generated page — `filename` was computed
    # but never interpolated into the iframe src.
    gradio_html = f"""<div>
<iframe width=612 height=612 style="margin: 0 auto" src="static/{filename}"></iframe>
<p style="text-align:center">Press the arrow keys to move. Press <code>a</code> to run, <code>s</code> to jump and <code>d</code> to shoot fireflowers</p>
</div>"""
    return [prompt_image, gradio_html]
# Gradio UI: prompt settings, model picker, temperature, and output panes.
with gr.Blocks().queue() as demo:
    gr.Markdown(
        """## Playable demo of MarioGPT - Amazon Bedrock Edition
This is a demo of MarioGPT, a generative AI model trained on creating levels of the original Super Mario. By leveraging Amazon Bedrock, we can use just a prompt to design a level instead of using the original deep learning model.
You can try it out by entering in a prompt and clicking `Generate level` to play!!! You can also edit the system prompt, the model, and the temperature.
### Resources:
* [Blog Post](https://community.aws/content/2fVi2aodJJtY996luEq3U9aijp4/super-mario-bros-the-llm-levels---generate-levels-with-a-prompt)
* [Code](https://huggingface.co/spaces/banjtheman/mariobedrock/tree/main)
* [Original Paper](https://arxiv.org/abs/2302.05981)
* [Amazon Bedrock](https://docs.aws.amazon.com/bedrock/latest/userguide/service_code_examples.html?trk=2403b700-9ee9-49e8-aed8-411dea5cf5ae&sc_channel=el)
"""
    )
    with gr.Tabs():
        with gr.TabItem("Prompt Settings"):
            with gr.Accordion(label="System Prompt", open=False):
                system_prompt = gr.TextArea(
                    value=system_prompt_text,
                    label="Enter your MarioGPT System prompt. ex: 'As an esteemed level designer renowned for creating some of the top 100 levels in Super Mario Maker...'",
                )
            text_prompt = gr.Textbox(
                value="Generate a level with a few pipes, many coins. make sure there are around 10 enemies. Make sure there is a ground path Mario can walk on",
                label="Enter your MarioGPT prompt. ex: 'Generate a level with a few pipes, many coins. make sure there are around 10 enemies. Make sure there is a ground path Mario can walk on'",
            )
            model = gr.Radio(
                [
                    "Claude Sonnet",
                    "Claude Haiku",
                    "Llama3 70B",
                    "Llama3 8B",
                    "Command R Plus",
                    # "Command R", # taking too long to generate
                ],
                label="Select Model",
                # Bug fix: default must be one of the listed choices —
                # "Claude Opus" was removed from the radio options.
                value="Claude Sonnet",
            )
            with gr.Accordion(label="Advanced settings", open=False):
                temperature = gr.Slider(
                    value=0.7,
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    label="Temperature: Increase for more randomness",
                )
    btn = gr.Button("Generate level")
    with gr.Row():
        with gr.Group():
            level_play = gr.HTML()
            level_image = gr.Image()
    btn.click(
        fn=generate,
        inputs=[
            model,
            text_prompt,
            temperature,
            system_prompt,
        ],
        outputs=[level_image, level_play],
    )
    gr.Examples(
        examples=[
            [
                "Claude Sonnet",
                "Generate a fun level, make sure Mario will have a good time!!!",
                1.0,
            ],
            [
                "Llama3 70B",
                "Design a level with blocks arranged in a pyramid-like shape, with coins scattered around the base and Goombas guarding the top. Have a pipe at the top. Have a path of blocks for Mario to walk on",
                0.7,
            ],
            [
                "Command R Plus",
                "Make a simple level that has no enemies, but lots and lots of coins. Lots of blocks for Mario to walk on.",
                0.7,
            ],
        ],
        # Bug fix: each example row has 3 values, so list exactly the 3
        # matching components. The original passed the plain string
        # system_prompt_text (not a component) as a 4th input; generate()
        # now falls back to its default system prompt for examples.
        inputs=[model, text_prompt, temperature],
        outputs=[level_image, level_play],
        fn=generate,
        cache_examples=False,
    )
# Serve generated demo pages and mario.jar from ./static.
app.mount("/static", StaticFiles(directory="static", html=True), name="static")
# Mount the Gradio UI at the web root of the same FastAPI app.
app = gr.mount_gradio_app(app, demo, "/")
# NOTE(review): starts the server at import time (no __main__ guard) —
# confirm this matches the deployment environment's expectations.
uvicorn.run(app, host="0.0.0.0", port=7860)