Your Name
Add Gradio application files
2ad4d00
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback
import os
from PIL import Image
import logging
import json
import numpy as np
import csv
import gymnasium
from vizdoom import gymnasium_wrapper # This import is needed to register the env
DATASET_DIR = "gamelogs"
FRAMES_DIR = os.path.join(DATASET_DIR, "frames")
os.makedirs(FRAMES_DIR, exist_ok=True)
class NpEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.integer):
return int(obj)
if isinstance(obj, np.floating):
return float(obj)
if isinstance(obj, np.ndarray):
return obj.tolist()
return super(NpEncoder, self).default(obj)
class GameNGenCallback(BaseCallback):
def __init__(self, verbose: bool, save_path: str):
super(GameNGenCallback, self).__init__(verbose)
self.save_path = save_path
self.frame_log = open(os.path.join(self.save_path, "metadata.csv"), mode="w", newline="")
self.csv_writer = csv.writer(self.frame_log)
# CSV Header
self.csv_writer.writerow(["frame_id", "action"])
def _on_step(self) -> bool:
frame_id = self.n_calls
key = f"{frame_id:09d}"
try:
obs_dict = self.locals["new_obs"]
# The observation from the callback is in Channels-First format (C, H, W)
frame_data = obs_dict['screen'][0]
action = self.locals["actions"][0]
# --- DEFINITIVE FIX ---
# Check if the frame is in the expected Channels-First format (C, H, W).
# A valid RGB image will have 3 channels in its first dimension.
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
# Pillow's fromarray function needs the image in Channels-Last format (H, W, C).
# We must transpose the axes from (C, H, W) to (H, W, C).
transposed_frame = np.transpose(frame_data, (1, 2, 0))
image = Image.fromarray(transposed_frame)
image.save(os.path.join(FRAMES_DIR, f"frame_{key}.png"))
json_action = json.dumps(action, cls=NpEncoder)
self.csv_writer.writerow([key, json_action])
else:
# This will now correctly catch the junk frames from terminal states.
logging.warning(f"Skipping corrupted frame {key} with invalid shape: {frame_data.shape}")
except Exception as e:
# This will now only catch truly unexpected errors.
logging.error(f"Could not process or save frame {key} due to an unexpected error: {e}")
return True
def _on_training_end(self) -> None:
self.frame_log.close()
# --- Main script ---
logging.basicConfig(level=logging.INFO)
# Create the VizDoom environment. No wrappers are needed.
env = gymnasium.make("VizdoomHealthGatheringSupreme-v0")
callback = GameNGenCallback(verbose=True, save_path=DATASET_DIR)
model = PPO(
"MultiInputPolicy",
env,
verbose=1,
)
model.learn(total_timesteps=2_000_000, callback=callback)
env.close()