Spaces:
Paused
Paused
# Provides a directory (registry) in which to look up the available agents.
import os | |
from typing import List, Optional | |
from seamless.simuleval_transcoder import SimulevalTranscoder | |
import json | |
import logging | |
# Logger named "gunicorn"; presumably chosen so records flow through the
# gunicorn server's log handlers -- confirm against the deployment setup.
logger = logging.getLogger("gunicorn")

# Language codes advertised as target languages by the default agent
# registered in this module. Order is preserved as-is.
# fmt: off
M4T_P0_LANGS = [
    "eng",
    "arb", "ben", "cat", "ces", "cmn",
    "cym", "dan", "deu", "est", "fin",
    "fra", "hin", "ind", "ita", "jpn",
    "kor", "mlt", "nld", "pes", "pol",
    "por", "ron", "rus", "slk", "spa",
    "swe", "swh", "tel", "tgl", "tha",
    "tur", "ukr", "urd", "uzn", "vie",
]
# fmt: on
class NoAvailableAgentException(Exception):
    """Raised when an agent is requested by a name that is not registered."""
class AgentWithInfo:
    """Bundle a built agent with the metadata describing its capabilities.

    The metadata (name, description, modalities, target languages and
    dynamic params) is what ``get_capabilities_for_json`` exposes.
    """

    def __init__(
        self,
        agent,
        name: str,
        modalities: List[str],
        target_langs: List[str],
        # Supported dynamic params are defined in StreamingTypes.ts
        dynamic_params: Optional[List[str]] = None,
        description="",
        has_expressive: Optional[bool] = None,
    ):
        """Store the agent and its descriptive metadata.

        Args:
            agent: The already-built agent object.
            name: Identifier used to look the agent up.
            modalities: Modalities the agent supports, e.g. ``["s2t", "s2s"]``.
            target_langs: Target language codes the agent supports.
            dynamic_params: Names of the dynamic params the agent accepts.
                ``None`` (the default) means "none", stored as a fresh list.
            description: Human-readable description.
            has_expressive: Optional flag stored as given; it is not part of
                the JSON capabilities.
        """
        self.agent = agent
        self.has_expressive = has_expressive
        self.name = name
        self.description = description
        self.modalities = modalities
        self.target_langs = target_langs
        # A literal ``[]`` default would be one list shared by every instance
        # created without this argument; allocate a new one per instance.
        self.dynamic_params = [] if dynamic_params is None else dynamic_params

    def get_capabilities_for_json(self):
        """Return the JSON-serializable capability description of the agent."""
        return {
            "name": self.name,
            "description": self.description,
            "modalities": self.modalities,
            "targetLangs": self.target_langs,
            "dynamicParams": self.dynamic_params,
        }

    @classmethod
    def load_from_json(cls, config: str):
        """
        Takes in JSON array of models to load in, e.g.
        [{"name": "s2s_m4t_emma-unity2_multidomain_v0.1", "description": "M4T model that supports simultaneous S2S and S2T", "modalities": ["s2t", "s2s"], "targetLangs": ["en"]},
        {"name": "s2s_m4t_expr-emma_v0.1", "description": "ES-EN expressive model that supports S2S and S2T", "modalities": ["s2t", "s2s"], "targetLangs": ["en"]}]

        Builds one agent per entry via ``SimulevalTranscoder.build_agent`` and
        returns the resulting list of instances, in input order.

        Raises:
            KeyError: If an entry lacks "name", "modalities" or "targetLangs".
            json.JSONDecodeError: If ``config`` is not valid JSON.
        """
        agents = []
        # Distinct loop name so the ``config`` parameter is not shadowed.
        for model_config in json.loads(config):
            agent = SimulevalTranscoder.build_agent(model_config["name"])
            agents.append(
                cls(
                    agent=agent,
                    name=model_config["name"],
                    modalities=model_config["modalities"],
                    target_langs=model_config["targetLangs"],
                    # "description" is optional in the config; keep the
                    # constructor default when it is absent.
                    description=model_config.get("description", ""),
                )
            )
        return agents
class SimulevalAgentDirectory:
    """Registry of built agents that can be looked up by name."""

    # Available models. These are the directories where the models can be found, and also serve as an ID for the model.
    seamless_streaming_agent = "SeamlessStreaming"
    seamless_agent = "Seamless"

    def __init__(self):
        """Start with an empty registry that has not been populated yet."""
        self.agents = []
        # Set to True once build_and_add_agents() has run to completion, so
        # later calls return immediately instead of rebuilding.
        self.did_build_and_add_agents = False

    def add_agent(self, agent: "AgentWithInfo"):
        """Append ``agent`` to the registry."""
        self.agents.append(agent)

    def build_agent_if_available(self, model_id, config_name=None):
        """Build and return the agent for ``model_id``.

        Args:
            model_id: Model identifier passed to
                ``SimulevalTranscoder.build_agent``.
            config_name: Optional config file name; forwarded as a keyword
                argument only when it is not ``None``.

        Returns:
            Whatever ``SimulevalTranscoder.build_agent`` returns.

        Raises:
            Exception: Any build failure is logged as a warning and then
                re-raised unchanged.
        """
        build_kwargs = {} if config_name is None else {"config_name": config_name}
        try:
            return SimulevalTranscoder.build_agent(model_id, **build_kwargs)
        except Exception as e:
            logger.warning("Failed to build agent %s: %s", model_id, e)
            # The import is guarded so that a missing/relocated fairseq2
            # module cannot replace the real build failure with ImportError.
            try:
                from fairseq2.assets.error import AssetError
            except ImportError:
                AssetError = None
            if AssetError is not None and isinstance(e, AssetError):
                logger.warning(
                    "Please download gated assets and set `gated_model_dir` in the config"
                )
            raise

    def build_and_add_agents(self, models_override=None):
        """Populate the registry once; subsequent calls are no-ops.

        Args:
            models_override: Optional JSON string handed to
                ``AgentWithInfo.load_from_json``. When ``None``, a single
                default agent is built; the ``USE_EXPRESSIVE_MODEL``
                environment variable (``"1"`` enables it) selects which one.

        Raises:
            Exception: Propagated from agent building; in that case the
                registry is not marked as populated.
        """
        if self.did_build_and_add_agents:
            return

        if models_override is not None:
            for agent_info in AgentWithInfo.load_from_json(models_override):
                self.add_agent(agent_info)
        else:
            has_expressive = os.environ.get("USE_EXPRESSIVE_MODEL", "0") == "1"
            if has_expressive:
                logger.info("Building expressive model...")
                s2s_agent = self.build_agent_if_available(
                    SimulevalAgentDirectory.seamless_agent,
                    config_name="vad_s2st_sc_24khz_main.yaml",
                )
            else:
                logger.info("Building non-expressive model...")
                s2s_agent = self.build_agent_if_available(
                    SimulevalAgentDirectory.seamless_streaming_agent,
                    config_name="vad_s2st_sc_main.yaml",
                )

            if s2s_agent:
                # NOTE(review): the agent is registered under the
                # "SeamlessStreaming" name and the same description for both
                # variants; only has_expressive differs. Presumably clients
                # look it up by that fixed name -- confirm before changing.
                self.add_agent(
                    AgentWithInfo(
                        agent=s2s_agent,
                        name=SimulevalAgentDirectory.seamless_streaming_agent,
                        modalities=["s2t", "s2s"],
                        target_langs=M4T_P0_LANGS,
                        dynamic_params=["expressive"],
                        description="multilingual expressive model that supports S2S and S2T",
                        has_expressive=has_expressive,
                    )
                )

        if not self.agents:
            logger.error(
                "No agents were loaded. This likely means you are missing the actual model files specified in simuleval_agent_directory."
            )

        self.did_build_and_add_agents = True

    def get_agent(self, name):
        """Return the first registered agent named ``name``, or ``None``."""
        return next((agent for agent in self.agents if agent.name == name), None)

    def get_agent_or_throw(self, name):
        """Return the agent named ``name``.

        Raises:
            NoAvailableAgentException: If no registered agent has that name.
        """
        agent = self.get_agent(name)
        if agent is None:
            # One-element tuple so a tuple-valued name cannot break formatting.
            raise NoAvailableAgentException("No agent found with name= %s" % (name,))
        return agent

    def get_agents_capabilities_list_for_json(self):
        """Return the capabilities dict of every registered agent, in order."""
        return [agent.get_capabilities_for_json() for agent in self.agents]