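"""Websocket server for Moshi: receives Opus audio from a client, runs it
through the Mimi codec and the Moshi language model, and streams the generated
speech (and text tokens) back over the same websocket."""
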
import argparse
import asyncio
from dataclasses import dataclass
import random
import os
from pathlib import Path
import tarfile
import time
import secrets
import sys

import aiohttp
from aiohttp import web
from huggingface_hub import hf_hub_download
import numpy as np
import sentencepiece
import sphn
import torch

from .client_utils import make_log
from .models import loaders, MimiModel, LMModel, LMGen


def log(level: str, msg: str):
    print(make_log(level, msg))


def seed_all(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # also seed any secondary GPUs.
    random.seed(seed)
    np.random.seed(seed)
    # A seeding helper should make runs reproducible: use deterministic cuDNN
    # kernels and disable autotuning.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


@dataclass
class ServerState:
    mimi: MimiModel
    text_tokenizer: sentencepiece.SentencePieceProcessor
    lm_gen: LMGen
    lock: asyncio.Lock

    def __init__(self, mimi: MimiModel, text_tokenizer: sentencepiece.SentencePieceProcessor,
                 lm: LMModel, device: str | torch.device):
        self.mimi = mimi
        self.text_tokenizer = text_tokenizer
        self.lm_gen = LMGen(lm)
        self.device = device
        # Number of PCM samples per Mimi frame (sample rate / frame rate).
        self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
        self.lock = asyncio.Lock()

        # Keep streaming state alive across calls, with batch size 1.
        self.mimi.streaming_forever(1)
        self.lm_gen.streaming_forever(1)

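    # Runs a few frames of silence through the full encode/step/decode path,
    # presumably so kernels are compiled and caches allocated before the first
    # real client connects.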
    def warmup(self):
        for _ in range(4):
            chunk = torch.zeros(1, 1, self.frame_size, dtype=torch.float32, device=self.device)
            codes = self.mimi.encode(chunk)
            for c in range(codes.shape[-1]):
                tokens = self.lm_gen.step(codes[:, :, c: c + 1])
                if tokens is None:
                    continue
                _ = self.mimi.decode(tokens[:, 1:])
        if torch.cuda.is_available():
            torch.cuda.synchronize()

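    # Wire protocol, as implemented below: every websocket message is binary
    # and starts with a one-byte kind tag. Client -> server: 0x01 + Opus audio.
    # Server -> client: 0x00 once when ready (handshake), 0x01 + Opus audio,
    # 0x02 + UTF-8 text.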
    async def handle_chat(self, request):
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        async def recv_loop():
            nonlocal close
            try:
                async for message in ws:
                    if message.type == aiohttp.WSMsgType.ERROR:
                        log("error", f"{ws.exception()}")
                        break
                    elif message.type == aiohttp.WSMsgType.CLOSED:
                        break
                    elif message.type != aiohttp.WSMsgType.BINARY:
                        log("error", f"unexpected message type {message.type}")
                        continue
                    message = message.data
                    if not isinstance(message, bytes):
                        log("error", f"unsupported message type {type(message)}")
                        continue
                    if len(message) == 0:
                        log("warning", "empty message")
                        continue
                    kind = message[0]
                    if kind == 1:  # 0x01: Opus audio from the client.
                        payload = message[1:]
                        opus_reader.append_bytes(payload)
                    else:
                        log("warning", f"unknown message kind {kind}")
            finally:
                close = True
                log("info", "connection closed")

        async def opus_loop():
            all_pcm_data = None

            while True:
                if close:
                    return
                await asyncio.sleep(0.001)
                pcm = opus_reader.read_pcm()
                if pcm.shape[-1] == 0:
                    continue
                # Accumulate decoded PCM until at least one full frame is buffered.
                if all_pcm_data is None:
                    all_pcm_data = pcm
                else:
                    all_pcm_data = np.concatenate((all_pcm_data, pcm))
                while all_pcm_data.shape[-1] >= self.frame_size:
                    be = time.time()
                    chunk = all_pcm_data[: self.frame_size]
                    all_pcm_data = all_pcm_data[self.frame_size:]
                    chunk = torch.from_numpy(chunk)
                    chunk = chunk.to(device=self.device)[None, None]
                    codes = self.mimi.encode(chunk)
                    for c in range(codes.shape[-1]):
                        tokens = self.lm_gen.step(codes[:, :, c: c + 1])
                        if tokens is None:
                            continue
                        assert tokens.shape[1] == self.lm_gen.lm_model.dep_q + 1
                        # The first codebook holds the text token; the rest are audio codes.
                        main_pcm = self.mimi.decode(tokens[:, 1:])
                        main_pcm = main_pcm.cpu()
                        opus_writer.append_pcm(main_pcm[0, 0].numpy())
                        text_token = tokens[0, 0, 0].item()
                        if text_token not in (0, 3):  # ids 0 and 3 are special, non-text tokens.
                            _text = self.text_tokenizer.id_to_piece(text_token)
                            _text = _text.replace("▁", " ")
                            msg = b"\x02" + bytes(_text, encoding="utf8")
                            log("info", f"text token '{_text}'")
                            await ws.send_bytes(msg)
                    log("info", f"frame handled in {1000 * (time.time() - be):.1f}ms")

        async def send_loop():
            while True:
                if close:
                    return
                await asyncio.sleep(0.001)
                msg = opus_writer.read_bytes()
                if len(msg) > 0:
                    # 0x01: Opus-encoded model speech back to the client.
                    await ws.send_bytes(b"\x01" + msg)

| | log("info", "accepted connection") |
| | close = False |
| | async with self.lock: |
| | opus_writer = sphn.OpusStreamWriter(self.mimi.sample_rate) |
| | opus_reader = sphn.OpusStreamReader(self.mimi.sample_rate) |
| | self.mimi.reset_streaming() |
| | self.lm_gen.reset_streaming() |
| | |
| | await ws.send_bytes(b"\x00") |
| | await asyncio.gather(opus_loop(), recv_loop(), send_loop()) |
| | log("info", "done with connection") |
| | return ws |
| |
|
| |
|
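# Entry point: parse flags, seed all RNGs, fetch any missing weights from the
# Hugging Face hub, warm up the models, then serve the static web UI and the
# /api/chat websocket endpoint.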
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="localhost", type=str)
    parser.add_argument("--port", default=8998, type=int)
    parser.add_argument("--static", type=str)
    parser.add_argument("--gradio-tunnel", action="store_true", help="Activate a gradio tunnel.")
    parser.add_argument("--gradio-tunnel-token",
                        help="Provide a custom (secret) token here to keep getting the same URL.")

    parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
    parser.add_argument("--moshi-weight", type=str, help="Path to a local checkpoint file for Moshi.")
    parser.add_argument("--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi.")
    parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO,
                        help="HF repo to look into, defaults to Moshiko. "
                             "Use this to select a different pre-trained model.")
    parser.add_argument("--device", type=str, default="cuda", help="Device on which to run, defaults to 'cuda'.")

    args = parser.parse_args()
    seed_all(42424242)

    setup_tunnel = None
    tunnel_token = ""
    if args.gradio_tunnel:
        try:
            from gradio import networking
        except ImportError:
            log("error", "Cannot find gradio, which is required to activate a tunnel. "
                         "Please install it with `pip install gradio`.")
            sys.exit(1)
        setup_tunnel = networking.setup_tunnel
        if args.gradio_tunnel_token is None:
            tunnel_token = secrets.token_urlsafe(32)
        else:
            tunnel_token = args.gradio_tunnel_token

| | log("info", "loading mimi") |
| | if args.mimi_weight is None: |
| | args.mimi_weight = hf_hub_download(args.hf_repo, loaders.MIMI_NAME) |
| | mimi = loaders.get_mimi(args.mimi_weight, args.device) |
| | log("info", "mimi loaded") |
| |
|
| | if args.tokenizer is None: |
| | args.tokenizer = hf_hub_download(args.hf_repo, loaders.TEXT_TOKENIZER_NAME) |
| | text_tokenizer = sentencepiece.SentencePieceProcessor(args.tokenizer) |
| |
|
| | log("info", "loading moshi") |
| | if args.moshi_weight is None: |
| | args.moshi_weight = hf_hub_download(args.hf_repo, loaders.MOSHI_NAME) |
| | lm = loaders.get_moshi_lm(args.moshi_weight, args.device) |
| | log("info", "moshi loaded") |
| |
|
    state = ServerState(mimi, text_tokenizer, lm, args.device)
    log("info", "warming up the model")
    state.warmup()
    app = web.Application()
    app.router.add_get("/api/chat", state.handle_chat)
    static_path: None | str = None
    if args.static is None:
        log("info", "retrieving the static content")
        dist_tgz = hf_hub_download("kyutai/moshi-artifacts", "dist.tgz")
        dist_tgz = Path(dist_tgz)
        dist = dist_tgz.parent / "dist"
        if not dist.exists():
            with tarfile.open(dist_tgz, "r:gz") as tar:
                tar.extractall(path=dist_tgz.parent)
        static_path = str(dist)
    elif args.static != "none":
        # When set to the literal string "none", no static content is served.
        static_path = args.static
    if static_path is not None:
        async def handle_root(_):
            return web.FileResponse(os.path.join(static_path, "index.html"))

| | log("info", f"serving static content from {static_path}") |
| | app.router.add_get("/", handle_root) |
| | app.router.add_static( |
| | "/", path=static_path, follow_symlinks=True, name="static" |
| | ) |
| | log("info", f"Access the Web UI directly at http://{args.host}:{args.port}") |
| | if setup_tunnel is not None: |
| | tunnel = setup_tunnel('localhost', args.port, tunnel_token, None) |
| | log("info", f"Tunnel started, if executing on a remote GPU, you can use {tunnel}.") |
| | log("info", "Note that this tunnel goes through the US and you might experience high latency in Europe.") |
| | web.run_app(app, port=args.port) |
| |
|
if __name__ == "__main__":
    # Inference only: run the whole server without autograd bookkeeping.
    with torch.no_grad():
        main()
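

# A minimal client sketch (an illustration, not part of the server): assuming
# the server runs locally on the default port, this connects to /api/chat and
# waits for the 0x00 handshake byte.
#
#     import aiohttp
#     import asyncio
#
#     async def probe():
#         async with aiohttp.ClientSession() as session:
#             async with session.ws_connect("http://localhost:8998/api/chat") as ws:
#                 msg = await ws.receive()
#                 assert msg.type == aiohttp.WSMsgType.BINARY and msg.data[0] == 0
#
#     asyncio.run(probe())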