import asyncio

import torch
from pydub import AudioSegment  # used only by the commented-out preprocessing below
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route
from transformers import pipeline

# Sample input: the first 3 minutes of a recording, exported to FLAC.
audio_path = "layout_detection_3min.flac"
# One-time preprocessing (uncomment to regenerate the FLAC from the source m4a):
# audio = AudioSegment.from_file("./data/Layout Detection.m4a", "m4a")
# audio = audio[: 180 * 1000]  # pydub slices in milliseconds; keep the first 180 s
# audio.export(audio_path, format="flac")

# Pin inference to a specific GPU if available, otherwise fall back to CPU.
device = "cuda:2" if torch.cuda.is_available() else "cpu"


# Most recent transcription result, exposed on /show for polling.
res = None


async def homepage(request):
    """Accept an audio file path in the POST body, queue it for
    transcription, and return the pipeline output as JSON."""
    global res
    payload = await request.body()
    audio_path = payload.decode("utf-8")
    # Per-request reply queue: the worker puts the result here.
    response_q = asyncio.Queue()
    await request.app.model_queue.put((audio_path, response_q))
    output = await response_q.get()
    res = output
    return JSONResponse(output)


async def show(request):
    """Return the most recent transcription result (None until the
    first request completes)."""
    print(f"request {request}")
    return JSONResponse(res)

async def server_loop(q):
    """Load the Whisper pipeline once, then serve queued requests
    sequentially so the model is never shared between two requests."""
    pipe = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-large",
        chunk_length_s=30,
        device=device,
    )
    # Force Chinese transcription instead of letting Whisper auto-detect
    # the language.
    pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(
        language="zh", task="transcribe"
    )
    while True:
        (audio_path, response_q) = await q.get()
        print(f"transcribing {audio_path}")
        # Note: pipe() is compute-bound and blocks the event loop for the
        # duration of the transcription.
        out = pipe(audio_path)
        await response_q.put(out)


app = Starlette(
    routes=[
        Route("/", homepage, methods=["POST"]),
        Route("/show", show, methods=["GET"])
    ],
)


@app.on_event("startup")
async def startup_event():
    # Single shared request queue, consumed by one background worker task.
    q = asyncio.Queue()
    app.model_queue = q
    asyncio.create_task(server_loop(q))