Spaces:
Sleeping
Sleeping
File size: 1,472 Bytes
e679d69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import importlib
import sys
from typing import Dict
import ray
from lagent.schema import AgentMessage
from lagent.utils import load_class_from_string
class AsyncAgentRayActor:
    """Async proxy that hosts an agent inside a Ray actor.

    The agent class named by ``config['type']`` is wrapped with
    ``ray.remote`` and instantiated remotely; awaiting a call on this proxy
    forwards it to the remote agent's ``__call__`` and returns its result.
    """

    def __init__(
        self,
        config: Dict,
        num_gpus: int,
    ):
        """Create the remote agent actor.

        Args:
            config: Agent construction config. ``config['type']`` is either
                the agent class itself or a string that is resolved with
                ``load_class_from_string``; the optional
                ``config['python_path']`` is passed to that resolver. All
                remaining keys are forwarded as keyword arguments to the
                agent's constructor. The caller's dict is not modified.
            num_gpus: Number of GPUs Ray reserves for the agent actor.

        Raises:
            KeyError: If ``config`` has no ``'type'`` entry.
        """
        # Pop from a shallow copy so the caller's dict keeps its 'type' /
        # 'python_path' keys and can be reused to build further actors.
        config = dict(config)
        agent_cls = config.pop('type')
        python_path = config.pop('python_path', None)
        if isinstance(agent_cls, str):
            agent_cls = load_class_from_string(agent_cls, python_path)
        remote_agent_cls = ray.remote(num_gpus=num_gpus)(agent_cls)
        self.agent_actor = remote_agent_cls.remote(**config)

    async def __call__(self, *message: AgentMessage, session_id=0, **kwargs):
        """Forward a call to the remote agent and await its result.

        Args:
            *message: Messages passed positionally to the remote agent.
            session_id: Session identifier forwarded to the remote agent.
            **kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            Whatever the remote agent's ``__call__`` returns.
        """
        # `.remote()` returns a Ray ObjectRef, which is awaitable under asyncio.
        response = await self.agent_actor.__call__.remote(
            *message, session_id=session_id, **kwargs)
        return response
class AgentRayActor:
    """Synchronous proxy that hosts an agent inside a Ray actor.

    The agent class named by ``config['type']`` is wrapped with
    ``ray.remote`` and instantiated remotely; calling this proxy forwards
    the call to the remote agent's ``__call__`` and blocks on ``ray.get``
    until the result is available.
    """

    def __init__(
        self,
        config: Dict,
        num_gpus: int,
    ):
        """Create the remote agent actor.

        Args:
            config: Agent construction config. ``config['type']`` is either
                the agent class itself or a string that is resolved with
                ``load_class_from_string``; the optional
                ``config['python_path']`` is passed to that resolver. All
                remaining keys are forwarded as keyword arguments to the
                agent's constructor. The caller's dict is not modified.
            num_gpus: Number of GPUs Ray reserves for the agent actor.

        Raises:
            KeyError: If ``config`` has no ``'type'`` entry.
        """
        # Pop from a shallow copy so the caller's dict keeps its 'type' /
        # 'python_path' keys and can be reused to build further actors.
        config = dict(config)
        agent_cls = config.pop('type')
        python_path = config.pop('python_path', None)
        if isinstance(agent_cls, str):
            agent_cls = load_class_from_string(agent_cls, python_path)
        remote_agent_cls = ray.remote(num_gpus=num_gpus)(agent_cls)
        self.agent_actor = remote_agent_cls.remote(**config)

    def __call__(self, *message: AgentMessage, session_id=0, **kwargs):
        """Forward a call to the remote agent and wait for its result.

        Args:
            *message: Messages passed positionally to the remote agent.
            session_id: Session identifier forwarded to the remote agent.
            **kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            Whatever the remote agent's ``__call__`` returns.
        """
        response = self.agent_actor.__call__.remote(
            *message, session_id=session_id, **kwargs)
        # Block until the remote call finishes and fetch its value.
        return ray.get(response)
|