import io | |
import os | |
from PIL import Image | |
def fig2img(fig, **savefig_kwargs):
    """Render a Matplotlib figure into an in-memory PIL Image.

    The figure is encoded into a ``BytesIO`` buffer via ``fig.savefig`` and
    the buffer is reopened with ``PIL.Image.open``.

    Args:
        fig: Object exposing a Matplotlib-style ``savefig(fname, **kwargs)``
            method (normally a ``matplotlib.figure.Figure``).
        **savefig_kwargs: Extra keyword arguments forwarded to
            ``fig.savefig`` (e.g. ``dpi``, ``bbox_inches``). ``format``
            defaults to ``"png"`` unless supplied explicitly.

    Returns:
        PIL.Image.Image: The rendered figure. The buffer is not closed here;
        the returned image reads from it.
    """
    # Pin the encoding format. A BytesIO has no file extension, so without
    # this savefig falls back to rcParams["savefig.format"], which may be a
    # vector format (pdf/svg/ps) that PIL cannot open.
    savefig_kwargs.setdefault("format", "png")
    buf = io.BytesIO()
    fig.savefig(buf, **savefig_kwargs)
    # Rewind so PIL starts reading at the beginning of the encoded image.
    buf.seek(0)
    return Image.open(buf)
def create_gif_from_frames(frames, filename, *, duration=200, loop=0):
    """Write a sequence of PIL Image frames to an animated image file.

    The first frame's ``save`` is called with ``save_all=True`` and the
    remaining frames passed as ``append_images``.

    Args:
        frames: Iterable of ``PIL.Image.Image`` frames, in display order.
            Must contain at least one frame.
        filename: Destination path (``str`` or path-like). Its parent
            directory is created if it does not exist. PIL picks the output
            format from the file extension, so use ``.gif`` for a GIF.
        duration: Display time of each frame, in milliseconds.
        loop: Loop count passed to PIL; ``0`` means loop forever.

    Raises:
        ValueError: If ``frames`` is empty.
    """
    # Materialize so generators/iterators work, then fail clearly on empty
    # input instead of with an IndexError from frames[0].
    frames = list(frames)
    if not frames:
        raise ValueError("create_gif_from_frames() needs at least one frame")

    path = os.fspath(filename)
    # Create the directory the file will actually be written to (the old
    # code always created a fixed 'episode_gifs' directory instead).
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    first, *rest = frames
    first.save(
        path,
        save_all=True,
        append_images=rest,
        duration=duration,
        loop=loop,
    )