rcastriotta committed on
Commit
17ff85c
1 Parent(s): e748991

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +162 -0
main.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Response
2
+ from fastapi.responses import JSONResponse
3
+ from pydantic import BaseModel
4
+ from audio_separator.separator import Separator
5
+ import ffmpeg
6
+ from datetime import datetime
7
+ import logging
8
+ import os
9
+ import uuid
10
+ from youtube_transcript_api import YouTubeTranscriptApi
11
+ import asyncio
12
+ from fastapi.concurrency import run_in_threadpool
13
+ from concurrent.futures import ThreadPoolExecutor
14
+
15
app = FastAPI()

# Scratch directory that holds the extracted clip and the separated stems.
tmp_directory = "tmp"
# isolate_voice has ffmpeg write the extracted clip straight into this
# directory, so make sure it exists before the first request arrives.
os.makedirs(tmp_directory, exist_ok=True)

separator = Separator(output_dir=tmp_directory, log_level=logging.INFO)
logging.getLogger().setLevel(logging.INFO)
# Load the separation model once at import time so every request reuses it.
separator.load_model("UVR-MDX-NET-Inst_Main.onnx")

# Dedicated thread pool for the separation jobs submitted by isolate_voice.
executor = ThreadPoolExecutor(max_workers=8)
22
+
23
+
24
class IsolationRequest(BaseModel):
    """Request body accepted by the ``/isolate`` endpoint."""

    # Media location handed to ffmpeg as its input.
    url: str
    # Offset passed to ffmpeg as the ``ss`` input option.
    start_time: float
    # Clip length passed to ffmpeg as the ``t`` output option.
    duration_seconds: float
28
+
29
+
30
@app.post("/isolate")
async def isolate_voice(request: IsolationRequest):
    """Extract a clip from the requested media and return its primary stem.

    The clip is cut with ffmpeg into a uniquely named WAV file inside
    ``tmp_directory``, run through the shared ``separator`` on the module's
    thread pool, and the primary stem is returned as ``audio/wav`` bytes.

    Args:
        request: URL of the media plus the start offset and clip duration.

    Returns:
        A ``Response`` whose body is the primary stem's WAV data.

    Raises:
        HTTPException: status 500 if extraction, separation, or reading the
            result fails for any reason.
    """
    # NOTE(review): request.url is passed to ffmpeg unvalidated; confirm the
    # service is only reachable by trusted callers or restrict the allowed
    # schemes/hosts.
    extracted_audio_path = f"{tmp_directory}/{uuid.uuid4()}.wav"
    # Every temporary file this request may have created. Stem files are
    # registered as soon as their names are known, so cleanup never touches
    # a variable that was not assigned because an earlier step failed.
    temp_paths = [extracted_audio_path]

    try:
        # TODO switch to CUDA
        await extract_audio(
            request.url,
            request.start_time,
            request.duration_seconds,
            extracted_audio_path,
        )

        # Run the separation off the event loop so other requests keep
        # being served.
        # NOTE(review): the same Separator instance is shared by all worker
        # threads; confirm it is safe to call concurrently.
        stem_outputs = list(
            await asyncio.get_running_loop().run_in_executor(
                executor,
                separator.separate,
                extracted_audio_path,
            )
        )
        temp_paths.extend(f"{tmp_directory}/{name}" for name in stem_outputs)

        # Exactly two stems are expected; anything else is reported as a
        # failure, as before.
        primary_stem_output_path, _secondary_stem_output_path = stem_outputs

        with open(f"{tmp_directory}/{primary_stem_output_path}", "rb") as f:
            isolated_audio_data = f.read()

    except Exception as e:
        logging.exception("An error occurred: %s", e)
        raise HTTPException(
            status_code=500, detail="An error occurred during vocal isolation"
        ) from e

    finally:
        # Remove each file independently: one missing file must not stop the
        # remaining files from being cleaned up.
        for path in temp_paths:
            try:
                os.remove(path)
            except OSError as e:
                logging.warning(
                    f"Error occurred while cleaning up temporary files: {str(e)}"
                )

    return Response(content=isolated_audio_data, media_type="audio/wav")
72
+
73
+
74
async def extract_audio(
    media_url: str, start_seconds: float, duration_seconds: float, output_path: str
) -> None:
    """Cut a WAV clip out of ``media_url`` with ffmpeg and log how long it took.

    The blocking ffmpeg invocation runs in the event loop's default executor
    so the coroutine does not stall other requests.

    Args:
        media_url: Input handed to ffmpeg.
        start_seconds: Value of ffmpeg's ``ss`` input option.
        duration_seconds: Value of ffmpeg's ``t`` output option.
        output_path: Destination file, written in ``wav`` format.

    Raises:
        Whatever the ffmpeg wrapper's ``run()`` raises when ffmpeg fails.
    """
    # Local import: the file's import block does not provide ``time``.
    import time

    def run_ffmpeg():
        return (
            ffmpeg.input(media_url, ss=start_seconds)
            .output(output_path, format="wav", t=duration_seconds)
            .global_args("-loglevel", "error", "-hide_banner")
            .global_args("-nostats")
            .run()
        )

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments the way datetime.now() differences can.
    started = time.perf_counter()
    await asyncio.get_running_loop().run_in_executor(
        None,  # Uses the default executor
        run_ffmpeg,
    )
    elapsed = time.perf_counter() - started

    logging.info(f"Audio extraction took {elapsed} seconds")
91
+
92
+
93
def scrape_subtitles(video_id, translate_to, translate_from):
    """Fetch subtitles for a video in ``translate_to``, translating if needed.

    Three strategies are tried in order, each on a best-effort basis:

    1. a transcript that already exists in ``translate_to``;
    2. the ``translate_from`` transcript translated to ``translate_to``;
    3. any listed transcript that can be translated to ``translate_to``.

    Args:
        video_id: Identifier of the video whose transcripts are listed.
        translate_to: Language code of the desired subtitles.
        translate_from: Language code of the video's own transcript.

    Returns:
        The fetched transcript, or ``None`` when no strategy succeeds.

    Raises:
        Whatever ``YouTubeTranscriptApi.list_transcripts`` raises when the
        transcript list itself cannot be retrieved.
    """
    transcript_list = YouTubeTranscriptApi.list_transcripts(
        video_id,
    )

    # ``except Exception`` (rather than a bare ``except``) keeps the
    # fall-through behaviour while letting KeyboardInterrupt / SystemExit
    # propagate.

    # see if translation already exists
    try:
        return transcript_list.find_transcript([translate_to]).fetch()
    except Exception:
        logging.debug(
            "No existing %s transcript for %s", translate_to, video_id, exc_info=True
        )

    # find transcription in video language
    try:
        return (
            transcript_list.find_transcript([translate_from])
            .translate(translate_to)
            .fetch()
        )
    except Exception:
        logging.debug(
            "Could not translate %s transcript of %s to %s",
            translate_from,
            video_id,
            translate_to,
            exc_info=True,
        )

    # search for any other translatable languages
    for transcript in transcript_list:
        try:
            return transcript.translate(translate_to).fetch()
        except Exception:
            continue

    return None
123
+
124
+
125
def format_language_code(lang: str) -> str:
    """Normalize a language tag to the code used for transcript lookups.

    An exact match in the override table wins (e.g. ``"zh-TW"`` ->
    ``"zh-Hant"``). Otherwise the region suffix is dropped and the base
    language is looked up in the same table, so ``"he-IL"`` maps to ``"iw"``
    exactly like ``"he"`` does, and ``"zh-CN"`` maps to ``"zh-Hans"`` exactly
    like ``"zh"``. Codes with no override are reduced to their base language
    (``"en-US"`` -> ``"en"``).

    Args:
        lang: Language tag, optionally with a ``-REGION`` suffix.

    Returns:
        The remapped code, or the base language when no override applies.
    """
    mapping = {
        "he": "iw",
        "zh": "zh-Hans",
        "zh-TW": "zh-Hant",
    }
    if lang in mapping:
        return mapping[lang]
    # Apply the overrides to the base language as well; previously a regional
    # tag such as "he-IL" bypassed the "he" -> "iw" remapping.
    base = lang.split("-")[0]
    return mapping.get(base, base)
132
+
133
+
134
class SubtitleRequest(BaseModel):
    """Request body accepted by the ``/subtitles`` endpoint."""

    # Identifier of the video whose transcripts are looked up.
    video_id: str
    # Language the subtitles should be returned in.
    translate_to: str
    # Language of the video's own transcript, used as the translation source.
    translate_from: str
138
+
139
+
140
@app.post("/subtitles")
async def get_subtitles(request: SubtitleRequest):
    """Return subtitles for a video in the requested language.

    Both language codes are normalized with ``format_language_code`` and the
    blocking lookup runs in the thread pool so the event loop stays free.

    Args:
        request: Video id plus the target and source language codes.

    Returns:
        A JSON response with the subtitles, or a 400 ``"Not available"``
        response when no transcript could be obtained.

    Raises:
        HTTPException: status 500 if the lookup raises.
    """
    try:
        subtitles = await run_in_threadpool(
            scrape_subtitles,
            request.video_id,
            format_language_code(request.translate_to),
            format_language_code(request.translate_from),
        )
    except Exception as e:
        # logging.warn is a deprecated alias; logging.warning is the
        # supported spelling.
        logging.warning(e)
        raise HTTPException(
            status_code=500, detail="An error occurred while getting subtitles"
        ) from e

    if subtitles is None:
        return Response("Not available", status_code=400)
    return JSONResponse(subtitles, status_code=200)
157
+
158
+
159
+ # if __name__ == "__main__":
160
+ # import uvicorn
161
+
162
+ # uvicorn.run(app, host="0.0.0.0", port=8000)