import streamlit as st
from transformers import WhisperForConditionalGeneration, WhisperProcessor
import torch
import librosa
import srt
from datetime import timedelta
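
# Streamlit app: upload WAV files, transcribe them in 5-second chunks with a
# fine-tuned Whisper model, and combine the results into one SRT subtitle file.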
# Split an audio signal into fixed-length segments (default: 5 seconds)
def split_audio(audio, sr, segment_duration=5):
    segments = []
    for i in range(0, len(audio), int(segment_duration * sr)):
        segment = audio[i:i + int(segment_duration * sr)]
        segments.append(segment)
    return segments
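
# Example: 16 s of 16 kHz mono audio (256,000 samples) splits into three full
# 80,000-sample segments plus a final 16,000-sample remainder.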
# ๋ชจ๋ธ ๋ฐ ํ”„๋กœ์„ธ์„œ ๋กœ๋“œ
@st.cache_resource
def load_model():
    model = WhisperForConditionalGeneration.from_pretrained("lcjln/AIME_Project_The_Final")
    processor = WhisperProcessor.from_pretrained("lcjln/AIME_The_Final")
    return model, processor
model, processor = load_model()
# Streamlit web application interface
st.title("Whisper Subtitle Generator")

# Upload multiple WAV files
uploaded_files = st.file_uploader("Drag and drop WAV files here", type=["wav"], accept_multiple_files=True)
# ํŒŒ์ผ ๋ชฉ๋ก์„ ๋ณด์—ฌ์คŒ
if uploaded_files:
st.write("์—…๋กœ๋“œ๋œ ํŒŒ์ผ ๋ชฉ๋ก:")
for uploaded_file in uploaded_files:
st.write(uploaded_file.name)
# Run button
if st.button("Run"):
    combined_subs = []
    # Timing is cumulative, so subtitles run continuously across files,
    # as if the uploaded WAVs were concatenated
    last_end_time = timedelta(0)
    subtitle_index = 1
    for uploaded_file in uploaded_files:
        st.write(f"Processing: {uploaded_file.name}")

        # Initialize the progress bar
        progress_bar = st.progress(0)

        # Load the WAV file, resampling to the 16 kHz rate Whisper expects
        st.write("Processing the audio file...")
        audio, sr = librosa.load(uploaded_file, sr=16000)
        progress_bar.progress(50)
        # Transcribe each segment with the Whisper model
        st.write("Generating subtitles with the model...")
        segments = split_audio(audio, sr, segment_duration=5)
        for segment in segments:
            inputs = processor(segment, return_tensors="pt", sampling_rate=16000)
            with torch.no_grad():
                # Whisper decoders accept at most 448 target positions, so cap
                # max_length there (the original 2048 exceeds the model limit)
                outputs = model.generate(inputs["input_features"], max_length=448, return_dict_in_generate=True, output_scores=True)
            # Decode the generated token IDs into text
            transcription = processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0].strip()
            # Rough confidence proxy: mean logit of the final decoding step
            avg_logit_score = torch.mean(outputs.scores[-1]).item()
# ์‹ ๋ขฐ๋„ ์ ์ˆ˜๊ฐ€ ๋‚ฎ๊ฑฐ๋‚˜ ํ…์ŠคํŠธ๊ฐ€ ๋น„์–ด์žˆ๋Š” ๊ฒฝ์šฐ ๋ฌด์‹œ
if transcription and avg_logit_score > -5.0:
segment_duration = librosa.get_duration(y=segment, sr=sr)
end_time = last_end_time + timedelta(seconds=segment_duration)
combined_subs.append(
srt.Subtitle(
index=subtitle_index,
start=last_end_time,
end=end_time,
content=transcription
)
)
last_end_time = end_time
subtitle_index += 1
        progress_bar.progress(100)
        st.success(f"Subtitles for {uploaded_file.name} were generated successfully!")
    # Merge all subtitles into a single SRT file
    st.write("Composing the final SRT file...")
    srt_content = srt.compose(combined_subs)
    final_srt_file_path = "combined_output.srt"
    with open(final_srt_file_path, "w", encoding="utf-8") as f:
        f.write(srt_content)
    st.success("The final SRT file was generated successfully!")
    # Download button for the final SRT file
    with open(final_srt_file_path, "rb") as srt_file:
        st.download_button(label="Download SRT file", data=srt_file, file_name=final_srt_file_path, mime="text/srt")
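
# To run the app locally (assuming streamlit, transformers, torch, librosa,
# and srt are installed):
#   streamlit run app.py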