File size: 2,752 Bytes
76e433a
 
 
 
 
 
 
 
58ac03b
 
 
 
76e433a
 
 
58ac03b
 
 
76e433a
 
 
 
 
58ac03b
 
 
 
 
 
 
 
76e433a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58ac03b
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import pathlib
import torch
from fairseq2.assets import InProcAssetMetadataProvider, asset_store
from seamless_communication.inference import Translator

# Directory holding the locally stored model files. Override the default
# location with the CHECKPOINTS_PATH environment variable.
CHECKPOINTS_PATH = pathlib.Path(os.getenv("CHECKPOINTS_PATH", "/home/user/app/models"))
# Fail fast with a clear message. is_dir() (rather than exists()) also rejects
# a path that points at a regular file, which could never hold the checkpoints.
if not CHECKPOINTS_PATH.is_dir():
    # The checkpoints can be fetched beforehand with, for example:
    #   from huggingface_hub import snapshot_download
    #   snapshot_download(repo_id="facebook/seamless-m4t-v2-large", repo_type="model", local_dir=CHECKPOINTS_PATH)
    raise FileNotFoundError(
        f"Checkpoint path {CHECKPOINTS_PATH} does not exist or is not a directory"
    )

# Replace every registered environment resolver with one that always answers
# "demo"; the card names registered below carry the matching "@demo" suffix.
asset_store.env_resolvers.clear()
asset_store.env_resolvers.append(lambda: "demo")

# Card overrides that point each asset at a file inside CHECKPOINTS_PATH.
# The upstream card and original download URL are noted next to each entry.
demo_metadata = []

# https://github.com/facebookresearch/seamless_communication/blob/dd67e71317d66752ef16cf21bd842ca3273244c9/src/seamless_communication/cards/seamlessM4T_v2_large.yaml#L10
# char_tokenizer: "https://huggingface.co/facebook/seamless-m4t-v2-large/resolve/main/spm_char_lang38_tc.model"
# checkpoint: "https://huggingface.co/facebook/seamless-m4t-v2-large/resolve/main/seamlessM4T_v2_large.pt"
demo_metadata.append(
    {
        "name": "seamlessM4T_v2_large@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/seamlessM4T_v2_large.pt",
        "char_tokenizer": f"file://{CHECKPOINTS_PATH}/spm_char_lang38_tc.model",
    }
)

# https://github.com/facebookresearch/seamless_communication/blob/dd67e71317d66752ef16cf21bd842ca3273244c9/src/seamless_communication/cards/unity_nllb-100.yaml#L9C1-L9C93
# tokenizer: "https://huggingface.co/facebook/seamless-m4t-large/resolve/main/tokenizer.model"
demo_metadata.append(
    {
        "name": "unity_nllb-100@demo",
        "tokenizer": f"file://{CHECKPOINTS_PATH}/tokenizer.model",
    }
)

# https://github.com/facebookresearch/seamless_communication/blob/dd67e71317d66752ef16cf21bd842ca3273244c9/src/seamless_communication/cards/vocoder_v2.yaml#L10
# checkpoint: "https://dl.fbaipublicfiles.com/seamless/models/vocoder_v2.pt"
demo_metadata.append(
    {
        "name": "vocoder_v2@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/vocoder_v2.pt",
    }
)

# Register the overrides with the global asset store.
asset_store.metadata_providers.append(InProcAssetMetadataProvider(demo_metadata))

# Choose the compute target once at import time: the first CUDA device with
# half precision when CUDA is available, otherwise the CPU with full precision.
device, dtype = (
    (torch.device("cuda:0"), torch.float16)
    if torch.cuda.is_available()
    else (torch.device("cpu"), torch.float32)
)

# Module-level Translator instance, constructed once at import time and used
# by the demo entry point below.
translator = Translator(
    # Where and in which precision the models run.
    device=device,
    dtype=dtype,
    # Card names for the translation model and the vocoder.
    model_name_or_card="seamlessM4T_v2_large",
    vocoder_name_or_card="vocoder_v2",
    # Enable the translator's MinTox option.
    apply_mintox=True,
)

if __name__ == "__main__":
    # Smoke test: run one text-to-text translation ("T2TT") of an English
    # sentence ("eng") into the "zsm" target language and print the result.
    input_text = "Hello, how are you today?"
    source_language_code = "eng"
    target_language_code = "zsm"

    # predict() returns a (text_outputs, speech_output) pair. text_outputs is
    # a list of translated strings, so take its first entry rather than
    # printing the repr of the whole list. The speech output is not needed for
    # a text-only task.
    text_outputs, _ = translator.predict(
        input=input_text,
        task_str="T2TT",
        src_lang=source_language_code,
        tgt_lang=target_language_code,
    )
    print(str(text_outputs[0]))