Banjo Obayomi commited on
Commit
60ec2f7
1 Parent(s): 4e00891

add prompt to image

Browse files
Files changed (3) hide show
  1. app.py +3 -2
  2. mario_gpt/font.ttf +0 -0
  3. mario_gpt/utils.py +26 -1
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  import uuid
3
  from mario_gpt.lm import MarioLM
4
- from mario_gpt.utils import convert_level_to_png
5
 
6
  from fastapi import FastAPI
7
  from fastapi.staticfiles import StaticFiles
@@ -473,12 +473,13 @@ def generate(model, prompt, temperature, system_prompt=system_prompt_text):
473
 
474
  filename = make_html_file(raw_level_text)
475
  img = convert_level_to_png(cleaned_level, mario_lm.tokenizer)[0]
 
476
 
477
  gradio_html = f"""<div>
478
  <iframe width=612 height=612 style="margin: 0 auto" src="static/{filename}"></iframe>
479
  <p style="text-align:center">Press the arrow keys to move. Press <code>a</code> to run, <code>s</code> to jump and <code>d</code> to shoot fireflowers</p>
480
  </div>"""
481
- return [img, gradio_html]
482
 
483
 
484
  with gr.Blocks().queue() as demo:
 
1
  import gradio as gr
2
  import uuid
3
  from mario_gpt.lm import MarioLM
4
+ from mario_gpt.utils import convert_level_to_png, add_prompt_to_image
5
 
6
  from fastapi import FastAPI
7
  from fastapi.staticfiles import StaticFiles
 
473
 
474
  filename = make_html_file(raw_level_text)
475
  img = convert_level_to_png(cleaned_level, mario_lm.tokenizer)[0]
476
+ prompt_image = add_prompt_to_image(img, prompt)
477
 
478
  gradio_html = f"""<div>
479
  <iframe width=612 height=612 style="margin: 0 auto" src="static/{filename}"></iframe>
480
  <p style="text-align:center">Press the arrow keys to move. Press <code>a</code> to run, <code>s</code> to jump and <code>d</code> to shoot fireflowers</p>
481
  </div>"""
482
+ return [prompt_image, gradio_html]
483
 
484
 
485
  with gr.Blocks().queue() as demo:
mario_gpt/font.ttf ADDED
Binary file (211 kB). View file
 
mario_gpt/utils.py CHANGED
@@ -4,7 +4,8 @@ from typing import List, Union
4
 
5
  import numpy as np
6
  import torch
7
- from PIL import Image
 
8
 
9
  pt = os.path.dirname(os.path.realpath(__file__))
10
  TILE_DIR = os.path.join(pt, "data", "tiles")
@@ -17,6 +18,30 @@ def trim_level(level):
17
  return level
18
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def characterize(str_lists):
21
  return [list(s[::-1]) for s in str_lists]
22
 
 
4
 
5
  import numpy as np
6
  import torch
7
+ from PIL import Image, ImageDraw, ImageFont
8
+ from textwrap import wrap
9
 
10
  pt = os.path.dirname(os.path.realpath(__file__))
11
  TILE_DIR = os.path.join(pt, "data", "tiles")
 
18
  return level
19
 
20
 
21
+ def add_prompt_to_image(img: Image.Image, prompt: str) -> Image.Image:
22
+ # Load a font for drawing the prompt text
23
+ font_path = os.path.join(pt, "font.ttf")
24
+ font = ImageFont.truetype(font_path, size=16)
25
+ # font = ImageFont.load_default()
26
+
27
+ # Create a drawing object
28
+ draw = ImageDraw.Draw(img)
29
+
30
+ # Wrap the prompt text if it's too long
31
+ max_width = img.width - 500 # Adjust the maximum width as desired
32
+ wrapped_text = "\n".join(wrap(prompt, width=max_width // font.getlength(" ")))
33
+
34
+ # Calculate the position to draw the prompt text
35
+ text_width, text_height = draw.textsize(prompt, font)
36
+ x = 10 # Adjust the x-coordinate as desired
37
+ y = text_height - 10 # Adjust the y-coordinate as desired
38
+
39
+ # Draw the prompt text on the image
40
+ draw.text((x, y), wrapped_text, font=font, fill=(255, 255, 255))
41
+
42
+ return img
43
+
44
+
45
  def characterize(str_lists):
46
  return [list(s[::-1]) for s in str_lists]
47