seamless-streaming / seamless_server /src /simuleval_agent_directory.py
Anna Sun
add dual non-expr/expressive agent, install sc from github
1143e8d
raw
history blame
6 kB
# Creates a directory in which to look up available agents
import os
from typing import List, Optional
from src.simuleval_transcoder import SimulevalTranscoder
import json
import logging
logger = logging.getLogger("socketio_server_pubsub")
# fmt: off
M4T_P0_LANGS = [
"eng",
"arb", "ben", "cat", "ces", "cmn", "cym", "dan",
"deu", "est", "fin", "fra", "hin", "ind", "ita",
"jpn", "kor", "mlt", "nld", "pes", "pol", "por",
"ron", "rus", "slk", "spa", "swe", "swh", "tel",
"tgl", "tha", "tur", "ukr", "urd", "uzn", "vie",
]
# fmt: on
class NoAvailableAgentException(Exception):
pass
class AgentWithInfo:
def __init__(
self,
agent,
name: str,
modalities: List[str],
target_langs: List[str],
# Supported dynamic params are defined in StreamingTypes.ts
dynamic_params: List[str] = [],
description="",
has_expressive: Optional[bool] = None,
):
self.agent = agent
self.has_expressive = has_expressive
self.name = name
self.description = description
self.modalities = modalities
self.target_langs = target_langs
self.dynamic_params = dynamic_params
def get_capabilities_for_json(self):
return {
"name": self.name,
"description": self.description,
"modalities": self.modalities,
"targetLangs": self.target_langs,
"dynamicParams": self.dynamic_params,
}
@classmethod
def load_from_json(cls, config: str):
"""
Takes in JSON array of models to load in, e.g.
[{"name": "s2s_m4t_emma-unity2_multidomain_v0.1", "description": "M4T model that supports simultaneous S2S and S2T", "modalities": ["s2t", "s2s"], "targetLangs": ["en"]},
{"name": "s2s_m4t_expr-emma_v0.1", "description": "ES-EN expressive model that supports S2S and S2T", "modalities": ["s2t", "s2s"], "targetLangs": ["en"]}]
"""
configs = json.loads(config)
agents = []
for config in configs:
agent = SimulevalTranscoder.build_agent(config["name"])
agents.append(
AgentWithInfo(
agent=agent,
name=config["name"],
modalities=config["modalities"],
target_langs=config["targetLangs"],
)
)
return agents
class SimulevalAgentDirectory:
# Available models. These are the directories where the models can be found, and also serve as an ID for the model.
seamless_streaming_agent = "SeamlessStreaming"
seamless_agent = "Seamless"
def __init__(self):
self.agents = []
self.did_build_and_add_agents = False
def add_agent(self, agent: AgentWithInfo):
self.agents.append(agent)
def build_agent_if_available(self, model_id, config_name=None):
agent = None
try:
if config_name is not None:
agent = SimulevalTranscoder.build_agent(
model_id,
config_name=config_name,
)
else:
agent = SimulevalTranscoder.build_agent(
model_id,
)
except Exception as e:
from fairseq2.assets.error import AssetError
logger.warning("Failed to build agent %s: %s" % (model_id, e))
if isinstance(e, AssetError):
logger.warning(
"Please download gated assets and set `gated_model_dir` in the config"
)
raise e
return agent
def build_and_add_agents(self, models_override=None):
if self.did_build_and_add_agents:
return
if models_override is not None:
agent_infos = AgentWithInfo.load_from_json(models_override)
for agent_info in agent_infos:
self.add_agent(agent_info)
else:
s2s_agent = None
if os.environ.get("USE_EXPRESSIVE_MODEL"):
logger.info("Building expressive model...")
s2s_agent = self.build_agent_if_available(
SimulevalAgentDirectory.seamless_agent,
config_name="vad_s2st_sc_24khz_main.yaml",
)
has_expressive = True
else:
logger.info("Building non-expressive model...")
s2s_agent = self.build_agent_if_available(
SimulevalAgentDirectory.seamless_streaming_agent,
config_name="vad_s2st_sc_main.yaml",
)
has_expressive = False
if s2s_agent:
self.add_agent(
AgentWithInfo(
agent=s2s_agent,
name=SimulevalAgentDirectory.seamless_streaming_agent,
modalities=["s2t", "s2s"],
target_langs=M4T_P0_LANGS,
dynamic_params=["expressive"],
description="multilingual expressive model that supports S2S and S2T",
has_expressive=has_expressive,
)
)
if len(self.agents) == 0:
logger.error(
"No agents were loaded. This likely means you are missing the actual model files specified in simuleval_agent_directory."
)
self.did_build_and_add_agents = True
def get_agent(self, name):
for agent in self.agents:
if agent.name == name:
return agent
return None
def get_agent_or_throw(self, name):
agent = self.get_agent(name)
if agent is None:
raise NoAvailableAgentException("No agent found with name= %s" % (name))
return agent
def get_agents_capabilities_list_for_json(self):
return [agent.get_capabilities_for_json() for agent in self.agents]