File size: 4,133 Bytes
01e655b
 
 
 
 
 
 
 
 
 
 
 
da8d589
 
01e655b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da8d589
01e655b
 
 
 
 
 
 
da8d589
 
 
 
01e655b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from typing import List, Optional

import torch
from fastapi import HTTPException
from pydantic import BaseModel

from modules.speaker import speaker_mgr
from modules.api import utils as api_utils
from modules.api.Api import APIManager


class CreateSpeaker(BaseModel):
    """Request body for ``/v1/speaker/create``.

    The create endpoint builds the speaker from ``tensor`` when it is a
    non-empty list, otherwise from ``seed``; with neither it responds 400.
    """

    name: str
    gender: str
    describe: str
    # Embedding values, converted with ``torch.tensor`` by the handler.
    # Declared Optional so an explicit JSON ``null`` validates (pydantic v2
    # does not make a field nullable just because its default is None).
    tensor: Optional[list] = None
    # Seed used to create the speaker when no tensor is supplied.
    seed: Optional[int] = None


class UpdateSpeaker(BaseModel):
    """Request body for ``/v1/speaker/update``.

    ``name``, ``gender`` and ``describe`` overwrite the stored values; the
    embedding is replaced only when ``tensor`` is a non-empty list.
    """

    id: str
    name: str
    gender: str
    describe: str
    # Optional so metadata-only updates need not resend the embedding;
    # None (or an empty list) leaves the existing embedding untouched,
    # matching how the handler already treats empty input.
    tensor: Optional[list] = None


class SpeakerDetail(BaseModel):
    """Request body for ``/v1/speaker/detail``."""

    # Identifier of the speaker to look up.
    id: str
    # Forwarded as ``to_json(with_emb=...)``; presumably toggles whether the
    # embedding is included in the response.
    with_emb: bool = False


class SpeakersUpdate(BaseModel):
    """Request body for ``/v1/speakers/update`` (batch update).

    Each entry is a dict carrying the speaker's ``"id"`` and, optionally,
    ``"name"``, ``"gender"``, ``"describe"`` and ``"tensor"`` keys.
    """

    # Typed as a list of dicts so malformed entries are rejected by request
    # validation (422) instead of crashing the handler with a 500.
    speakers: List[dict]


def _list_to_tensor(values):
    """Convert request embedding data to a tensor, or return None.

    Returns ``None`` when ``values`` is absent, empty, or not a list, which
    callers treat as "leave the embedding unchanged". Raises
    ``HTTPException(400)`` when ``torch.tensor`` cannot convert the list
    (e.g. ragged or non-numeric data) so the client gets a 400, not a 500.
    """
    if not values or not isinstance(values, list):
        return None
    try:
        # number array => Tensor
        return torch.tensor(values)
    except (TypeError, ValueError, RuntimeError) as err:
        raise HTTPException(
            status_code=400, detail=f"Invalid tensor: {err}"
        ) from err


def setup(app: APIManager):
    """Register the speaker management routes on ``app``."""

    @app.get("/v1/speakers/list", response_model=api_utils.BaseResponse)
    async def list_speakers():
        """Return every speaker from ``speaker_mgr`` serialized via ``to_json()``."""
        return {
            "message": "ok",
            "data": [spk.to_json() for spk in speaker_mgr.list_speakers()],
        }

    @app.post("/v1/speakers/update", response_model=api_utils.BaseResponse)
    async def update_speakers(request: SpeakersUpdate):
        """Batch-update speakers, then persist with ``speaker_mgr.save_all()``.

        All entries are validated first (entry is a dict with an ``"id"``,
        the speaker exists, the tensor converts). Only then is any speaker
        modified, so a bad entry produces an error without leaving earlier
        speakers changed in memory but unsaved.

        Raises:
            HTTPException: 400 for a malformed entry or tensor, 404 for an
                unknown speaker id.
        """
        pending = []
        for spk in request.speakers:
            if not isinstance(spk, dict) or "id" not in spk:
                raise HTTPException(
                    status_code=400,
                    detail="Each speaker entry must be an object with an 'id'",
                )
            speaker = speaker_mgr.get_speaker_by_id(spk["id"])
            if speaker is None:
                raise HTTPException(
                    status_code=404, detail=f"Speaker not found: {spk['id']}"
                )
            pending.append((speaker, spk, _list_to_tensor(spk.get("tensor"))))

        # Apply in request order; repeated ids update the same object in turn,
        # exactly as a single sequential pass would.
        for speaker, spk, emb in pending:
            speaker.name = spk.get("name", speaker.name)
            speaker.gender = spk.get("gender", speaker.gender)
            speaker.describe = spk.get("describe", speaker.describe)
            if emb is not None:
                speaker.emb = emb
        speaker_mgr.save_all()
        return {"message": "ok", "data": None}

    @app.post("/v1/speaker/create", response_model=api_utils.BaseResponse)
    async def create_speaker(request: CreateSpeaker):
        """Create a speaker from an explicit embedding or from a seed.

        A non-empty ``tensor`` takes precedence over ``seed``. ``seed`` is
        compared against ``None`` so that ``0`` is accepted as a valid seed.

        Raises:
            HTTPException: 400 when neither a usable tensor nor a seed is
                given, or when the tensor cannot be converted.
        """
        tensor = _list_to_tensor(request.tensor)
        if tensor is not None:
            speaker = speaker_mgr.create_speaker_from_tensor(
                tensor=tensor,
                name=request.name,
                gender=request.gender,
                describe=request.describe,
            )
        elif request.seed is not None:
            speaker = speaker_mgr.create_speaker_from_seed(
                seed=request.seed,
                name=request.name,
                gender=request.gender,
                describe=request.describe,
            )
        else:
            raise HTTPException(
                status_code=400, detail="Missing tensor or seed in request"
            )
        return {"message": "ok", "data": speaker.to_json()}

    @app.post("/v1/speaker/refresh", response_model=api_utils.BaseResponse)
    async def refresh_speakers():
        """Call ``speaker_mgr.refresh_speakers()``."""
        speaker_mgr.refresh_speakers()
        return {"message": "ok", "data": None}

    @app.post("/v1/speaker/update", response_model=api_utils.BaseResponse)
    async def update_speaker(request: UpdateSpeaker):
        """Overwrite one speaker's metadata (and embedding, if given).

        Raises:
            HTTPException: 404 for an unknown id, 400 for an invalid tensor.
        """
        speaker = speaker_mgr.get_speaker_by_id(request.id)
        if speaker is None:
            raise HTTPException(
                status_code=404, detail=f"Speaker not found: {request.id}"
            )
        # Convert before touching any field so an invalid tensor is rejected
        # without leaving the speaker partially modified.
        emb = _list_to_tensor(request.tensor)
        speaker.name = request.name
        speaker.gender = request.gender
        speaker.describe = request.describe
        if emb is not None:
            speaker.emb = emb
        speaker_mgr.update_speaker(speaker)
        return {"message": "ok", "data": None}

    @app.post("/v1/speaker/detail", response_model=api_utils.BaseResponse)
    async def speaker_detail(request: SpeakerDetail):
        """Return one speaker serialized via ``to_json(with_emb=...)``.

        Raises:
            HTTPException: 404 for an unknown id.
        """
        speaker = speaker_mgr.get_speaker_by_id(request.id)
        if speaker is None:
            raise HTTPException(
                status_code=404, detail=f"Speaker not found: {request.id}"
            )
        return {"message": "ok", "data": speaker.to_json(with_emb=request.with_emb)}