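"""Run a browser agent in a BrowserGym environment, score its candidate actions
with a checklist-based reward model, and stream status, checklist, reward-card,
and trajectory views to a web UI."""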
from pathlib import Path
import base64
import io
import logging
import multiprocessing
import os
import traceback

import gymnasium as gym
import numpy as np
from PIL import Image

from agent.checklist import generate_checklist
from agent.reward import get_ar_reward
from browser_agent import BrowserAgent

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

templates_dir = Path(__file__).parent / "templates"
CSS_RM_CARDS: str = (templates_dir / "rm_cards.css").read_text()
CSS_TRAJECTORY: str = (templates_dir / "trajectory.css").read_text()
CARD_HTML_TEMPLATE: str = (templates_dir / "card.html").read_text()

RM_BASE_URL = os.environ['RM_BASE_URL']
RM_MODEL_NAME = os.environ['RM_MODEL_NAME']
def return_state(state, screenshot=None):
    """Pack a status-only message into the 5-tuple format yielded by run_agent."""
    return state, None, None, screenshot, None
def run_agent(instruction: str, model_name: str = "gpt-4o", start_url: str = "about:blank",
              use_html: bool = False, use_axtree: bool = True, use_screenshot: bool = False, max_steps: int = 20):
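    """Roll out the browser agent on a BrowserGym task, scoring candidate actions
    with the reward model at each step.

    Yields 5-tuples of (status message, checklist, reward-model cards,
    current screenshot, trajectory) so the caller can stream intermediate state.
    """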
logger.info(f"Starting agent with instruction: {instruction}") | |
logger.info(f"Configuration: model={model_name}, start_url={start_url}") | |
trajectory = [] | |
trajectory_str = '' | |
agent = BrowserAgent( | |
model_name=model_name, | |
use_html=use_html, | |
use_axtree=use_axtree, | |
use_screenshot=use_screenshot | |
) | |
# Initialize BrowserGym environment | |
logger.info("Initializing BrowserGym environment") | |
yield return_state("## Initializing BrowserGym environment...", None) | |
env = gym.make( | |
"browsergym/openended", | |
task_kwargs={ | |
"start_url": start_url, | |
"goal": instruction, | |
}, | |
wait_for_user_message=True | |
) | |
obs, info = env.reset() | |
logger.info("Environment initialized") | |
# Send user instruction to the environment | |
logger.info("Sending user instruction to environment") | |
obs, reward, terminated, truncated, info = env.step({ | |
"type": "send_msg_to_user", | |
"message": instruction | |
}) | |
processed_obs = agent.obs_preprocessor(obs) | |
logger.info(f"Obs: {processed_obs.keys()}") | |
logger.info(f"axtree_txt: {processed_obs['axtree_txt']}") | |
yield return_state("## Generating checklist...", obs['som_screenshot']) | |
checklist = generate_checklist(intent=instruction, start_url=start_url, text_observation=processed_obs['axtree_txt']) | |
# yield initial state | |
current_screenshot = obs['som_screenshot'].copy() | |
yield "## Rollout actions from policy...", checklist, [], current_screenshot, trajectory.copy() | |
    try:
        step_count = 0
        while step_count < max_steps:
            logger.info(f"Step {step_count}: Getting next action")

            # Get candidate actions from the agent
            candidates, _ = agent.get_action(processed_obs)

            yield return_state("## Rewarding actions...", current_screenshot)
            total_rewards, total_thoughts = get_ar_reward(
                dataset=[
                    {
                        'text_observation': processed_obs['axtree_txt'],
                        'intent': instruction,
                        'trajectory': trajectory_str,
                        'current_url': processed_obs['open_pages_urls'][processed_obs['active_page_index'][0]],
                        'checklist': checklist,
                        'thought': cand['thought'],
                        'action': cand['action'],
                    } for cand in candidates
                ],
                base_url=RM_BASE_URL,
                model_name=RM_MODEL_NAME,
            )
            # Process rewards: gap between the highest reward and the first candidate (the most frequent action)
            diff_reward = abs(max(total_rewards) - total_rewards[0])
            if diff_reward <= 0.01:
                logger.info(f"diff_reward: {diff_reward} -> most frequent action")
                max_index = 0  # fall back to the most frequent action
            else:
                logger.info(f"diff_reward: {diff_reward} -> highest reward")
                max_index = total_rewards.index(max(total_rewards))  # pick the highest-reward action

            # Reorder candidates: the selected action first, the rest by descending reward
            sorted_indices = sorted(enumerate(total_rewards), key=lambda x: (-1 if x[0] == max_index else 0, -x[1]))
            new_order = [idx for idx, _ in sorted_indices]
            candidates = [candidates[idx] for idx in new_order]
            total_rewards = [total_rewards[idx] for idx in new_order]
            total_thoughts = [total_thoughts[idx] for idx in new_order]

            best_cand = candidates[0]
            agent.action_history.append(best_cand['response'])
            action = best_cand['action']

            # Collect per-step display data
            step_info = {
                'thought': best_cand['thought'],
                'action': action
            }
            current_cards = [
                {'thought': cand['thought'], 'action': cand['action'], 'feedback': feedback, 'reward': round(reward, 2)}
                for cand, reward, feedback in zip(candidates, total_rewards, total_thoughts)
            ]
            trajectory_str += f'THOUGHT {step_count+1}: {step_info["thought"]}\nACTION {step_count+1}: {step_info["action"]}\n\n'
            # Execute action
            logger.info(f"Step {step_count}: Executing action: {action}")
            yield f"## Executing action: {action}", checklist, current_cards, current_screenshot, trajectory.copy()
            if action.startswith('send_msg_to_user'):
                terminated = True
                truncated = False
            else:
                obs, reward, terminated, truncated, info = env.step(action)

            # Store the screenshot (as a numpy array) and the scored candidates for this step
            trajectory.append((processed_obs['som_screenshot'], [{'action': cand['action'], 'reward': round(reward, 2)} for cand, reward in zip(candidates, total_rewards)]))
            processed_obs = agent.obs_preprocessor(obs)
            current_screenshot = processed_obs['som_screenshot'].copy()

            # Collapse repeated blank lines in the thought text
            while '\n\n' in step_info['thought']:
                step_info['thought'] = step_info['thought'].replace('\n\n', '\n')

            logger.info(f"Step {step_count}: Saved screenshot and updated trajectory")
            step_count += 1

            # Yield the updated state after each step
            yield "## Rollout actions from policy...", checklist, current_cards, current_screenshot, trajectory.copy()

            if terminated or truncated:
                logger.info(f"Episode ended: terminated={terminated}, truncated={truncated}")
                yield return_state("## Episode ended", current_screenshot)
                break
    finally:
        env.close()
        logger.info("Finished")
def run_agent_worker(instruction, model_name, start_url, use_html, use_axtree, use_screenshot, max_steps, return_queue):
    """Worker that runs the agent in a separate process and puts results in a queue."""
    try:
        for result in run_agent(instruction, model_name, start_url, use_html, use_axtree, use_screenshot, max_steps):
            return_queue.put(result)
    except Exception as e:
        logger.error(f"Error in agent worker process: {e}")
        traceback.print_exc()
        # Push a well-formed 5-tuple so the consumer's unpacking does not break
        return_queue.put(("Error occurred in agent process", None, None, None, None))
    finally:
        # Signal that the process is done
        return_queue.put(None)
def run_agent_wrapper(instruction, model_name="gpt-4o", start_url="about:blank",
                      use_html=False, use_axtree=True, use_screenshot=False, max_steps=20):
    """Wrapper that runs the agent in a separate process and yields its results."""
    return_queue = multiprocessing.Queue()

    # Start the agent in a separate process
    p = multiprocessing.Process(
        target=run_agent_worker,
        args=(instruction, model_name, start_url, use_html, use_axtree, use_screenshot, max_steps, return_queue)
    )
    p.daemon = True  # Ensure the child terminates when the parent terminates
    p.start()

    # Drain the queue and yield results until the end signal arrives
    while True:
        result = return_queue.get()
        if result is None:  # End signal
            break
        yield result

    # Clean up
    if p.is_alive():
        p.terminate()
    p.join()
def process_run(instruction, model_name, start_url):
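    """Drive run_agent_wrapper and render its updates as HTML for the UI.

    Yields (status message, screenshot, checklist view, reward-model cards HTML,
    trajectory HTML) for each update coming out of the agent process.
    """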
    # Use the wrapper function instead of calling run_agent directly
    trajectory_generator = run_agent_wrapper(
        instruction,
        model_name,
        start_url,
        use_html=False,
        use_axtree=True,
        use_screenshot=False
    )

    all_trajectory = []
    last_checklist_view, last_trajectory_html = None, None
    for state, checklist_view, rm_cards, screenshot, trajectory in trajectory_generator:
        # Status-only updates carry no checklist; re-emit the last rendered views
        if checklist_view is None:
            yield state, screenshot, last_checklist_view, None, last_trajectory_html
            continue

        # Create HTML for reward model cards
        rm_cards_html = f"""
        <style>
        {CSS_RM_CARDS}
        </style>
        <div class="rm-cards-container">
        """
        for idx, card in enumerate(rm_cards):
            rm_cards_html += CARD_HTML_TEMPLATE.format(
                additional_class='top-candidate' if idx == 0 else '',
                k=idx + 1,
                suffix='(best)' if idx == 0 else '',
                thought=card['thought'],
                action=card['action'],
                reward=card['reward'],
                feedback=card['feedback']
            )
        rm_cards_html += "</div>"

        all_trajectory = trajectory

        # Create HTML for the trajectory display
        trajectory_html = f"""
        <style>
        {CSS_TRAJECTORY}
        </style>
        <div class="trajectory-container">
        """
        for idx, (step_img, cands) in enumerate(all_trajectory):
            # Convert the screenshot to a base64 data URI if needed
            img = step_img
            if isinstance(img, np.ndarray):
                img = Image.fromarray(img)
            if isinstance(img, Image.Image):
                buffer = io.BytesIO()
                img.save(buffer, format="JPEG")
                img_str = base64.b64encode(buffer.getvalue()).decode()
                img_src = f"data:image/jpeg;base64,{img_str}"
            else:
                img_src = img

            trajectory_html += f"""
            <div class="step-container">
                <div class="step-header">Step {idx + 1}</div>
                <div class="step-content">
                    <div class="step-image">
                        <img src="{img_src}" alt="Browser state">
                    </div>
                    <div class="step-info">
                        <div class="box-title">Action Candidates:</div>
                        <div class="action-candidates">
            """
            # Display all candidates for this step
            for i, cand in enumerate(cands):
                action = cand['action']
                reward = cand['reward']
                trajectory_html += f"""
                <div class="candidate-box{' selected' if i == 0 else ''}">
                    <div class="box-title">
                        Action {i+1}{' (Selected)' if i == 0 else ''}
                        <span class="reward-text">Reward: {reward}</span>
                    </div>
                    <pre>{action}</pre>
                </div>
                """
            trajectory_html += """
                        </div>
                    </div>
                </div>
            </div>
            """
        trajectory_html += "</div>"

        last_checklist_view, last_trajectory_html = checklist_view, trajectory_html
        yield state, screenshot, last_checklist_view, rm_cards_html, last_trajectory_html

    # Re-emit the final state once the generator is exhausted
    yield state, screenshot, last_checklist_view, rm_cards_html, last_trajectory_html
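
# The block below is a minimal, illustrative sketch of how process_run could be
# served as a Gradio app (the usual entry point for a Hugging Face Space). It is
# an assumption rather than the project's actual UI: the component names, labels,
# and layout are hypothetical; only the 5-tuple yielded by process_run is taken
# from the code above.
if __name__ == "__main__":
    import gradio as gr  # assumed dependency for the Space UI

    with gr.Blocks() as demo:
        instruction_box = gr.Textbox(label="Instruction")
        model_box = gr.Textbox(label="Model name", value="gpt-4o")
        url_box = gr.Textbox(label="Start URL", value="about:blank")
        run_button = gr.Button("Run")

        status_view = gr.Markdown()
        screenshot_view = gr.Image(label="Current screenshot")
        checklist_view = gr.Markdown()
        cards_view = gr.HTML()
        trajectory_view = gr.HTML()

        # process_run is a generator, so Gradio streams each yielded update
        run_button.click(
            fn=process_run,
            inputs=[instruction_box, model_box, url_box],
            outputs=[status_view, screenshot_view, checklist_view, cards_view, trajectory_view],
        )

    demo.launch()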