# NOTE(review): the three lines below were residue from the hosting page
# (Hugging Face Spaces header and its "Runtime error" status banner) captured
# during extraction; kept here as comments so the file remains valid Python.
# Spaces:
# Runtime error
# Runtime error
import os
import subprocess

import matplotlib.pyplot as plt
import streamlit as st
import torch
import torchaudio
import torchaudio.transforms as T
import yaml
from pytube import YouTube
from st_audiorec import st_audiorec
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from src.models import models
# ๋ช ๋ น์ด ์คํ | |
command = "apt-get update" | |
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) | |
# ๋ช ๋ น์ด ์คํ ๊ฒฐ๊ณผ ์ถ๋ ฅ | |
stdout, stderr = process.communicate() | |
print(stdout, stderr) | |
command = "apt-get install sox libsox-dev -y" | |
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) | |
# ๋ช ๋ น์ด ์คํ ๊ฒฐ๊ณผ ์ถ๋ ฅ | |
stdout, stderr = process.communicate() | |
print(stdout, stderr) | |
# ์ ์ฒ๋ฆฌ ํจ์ | |
SAMPLING_RATE = 16_000 | |
def apply_preprocessing( | |
waveform, | |
sample_rate, | |
): | |
if sample_rate != SAMPLING_RATE and SAMPLING_RATE != -1: | |
waveform, sample_rate = resample_wave(waveform, sample_rate, SAMPLING_RATE) | |
# Stereo to mono | |
if waveform.dim() > 1 and waveform.shape[0] > 1: | |
waveform = waveform[:1, ...] | |
waveform, sample_rate = apply_trim(waveform, sample_rate) | |
waveform = apply_pad(waveform, 480_000) | |
return waveform, sample_rate | |
def resample_wave(waveform, sample_rate, target_sample_rate): | |
waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor( | |
waveform, sample_rate, [["rate", f"{target_sample_rate}"]] | |
) | |
return waveform, sample_rate | |
def apply_trim(waveform, sample_rate): | |
( | |
waveform_trimmed, | |
sample_rate_trimmed, | |
) = torchaudio.sox_effects.apply_effects_tensor(waveform, sample_rate, [["silence", "1", "0.2", "1%", "-1", "0.2", "1%"]]) | |
if waveform_trimmed.size()[1] > 0: | |
waveform = waveform_trimmed | |
sample_rate = sample_rate_trimmed | |
return waveform, sample_rate | |
def apply_pad(waveform, cut): | |
"""Pad wave by repeating signal until `cut` length is achieved.""" | |
waveform = waveform.squeeze(0) | |
waveform_len = waveform.shape[0] | |
if waveform_len >= cut: | |
return waveform[:cut] | |
# need to pad | |
num_repeats = int(cut / waveform_len) + 1 | |
padded_waveform = torch.tile(waveform, (1, num_repeats))[:, :cut][0] | |
return padded_waveform | |
# | |
# | |
# | |
# ๋ชจ๋ธ ์ค์ ๋ฐ ๋ก๋ฉ | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
with open('augmentation_ko_whisper_frontend_lcnn_mfcc.yaml', 'r') as f: | |
model_config = yaml.safe_load(f) | |
model_paths = model_config["checkpoint"]["path"] | |
model_name, model_parameters = model_config["model"]["name"], model_config["model"]["parameters"] | |
model = models.get_model( | |
model_name=model_name, | |
config=model_parameters, | |
device=device, | |
) | |
model.load_state_dict(torch.load(model_paths, map_location=torch.device('cpu'))) | |
model = model.to(device) | |
model.eval() | |
# YouTube ๋น๋์ค ๋ค์ด๋ก๋ ๋ฐ ์ค๋์ค ์ถ์ถ ํจ์ | |
def download_youtube_audio(youtube_url, output_path="temp"): | |
yt = YouTube(youtube_url) | |
audio_stream = yt.streams.get_audio_only() | |
output_file = audio_stream.download(output_path=output_path) | |
title = audio_stream.default_filename | |
return output_file, title | |
# URL๋ก๋ถํฐ ์์ธก | |
def pred_from_url(youtube_url, segment_length=30): | |
global model | |
audio_path, title = download_youtube_audio(youtube_url) | |
waveform, sample_rate = torchaudio.load(audio_path, normalize=True) | |
if waveform.size(0) > 1: | |
waveform = waveform.mean(dim=0, keepdim=True) | |
num_samples_per_segment = int(segment_length * sample_rate) | |
total_samples = waveform.size(1) | |
if total_samples <= num_samples_per_segment: | |
num_samples_per_segment = total_samples | |
num_segments = 1 | |
else: | |
num_segments = total_samples // num_samples_per_segment | |
preds = [] | |
print(num_segments) | |
for i in range(num_segments): | |
# ๊ฐ ๊ตฌ๊ฐ์ ๋ํ ์ถ๋ก ์งํ | |
start_sample = i * num_samples_per_segment | |
end_sample = start_sample + num_samples_per_segment | |
segment = waveform[:, start_sample:end_sample] | |
segment, sample_rate = apply_preprocessing(segment, sample_rate) | |
pred = model(segment.unsqueeze(0).to(device)) | |
pred = torch.sigmoid(pred) | |
preds.append(pred.item()) | |
# ํ๊ท ์์ธก๊ฐ ๊ณ์ฐ | |
avg_pred = torch.tensor(preds).mean().item() | |
os.remove(audio_path) | |
return (f"{title}\n\n{(100-avg_pred*100):.2f}% ํ๋ฅ ๋ก fake์ ๋๋ค.") | |
# ํ์ผ๋ก๋ถํฐ ์์ธก | |
def pred_from_file(file_path): | |
global model | |
waveform, sample_rate = torchaudio.load(file_path, normalize=True) | |
waveform, sample_rate = apply_preprocessing(waveform, sample_rate) | |
pred = model(waveform.unsqueeze(0).to(device)) | |
pred = torch.sigmoid(pred) | |
return f"{(100-pred[0][0]*100):.2f}% ํ๋ฅ ๋ก fake์ ๋๋ค." | |
def pred_from_realtime_audio(duration=5, sample_rate=SAMPLING_RATE): | |
global model | |
def record_audio(duration, sample_rate): | |
print("Recording...") | |
recording = sd.rec(int(duration * sample_rate), samplerate=sample_rate, channels=1) | |
sd.wait() # Wait until recording is finished | |
print("Recording finished.") | |
return recording | |
waveform = record_audio(duration, sample_rate) | |
waveform = torch.tensor(waveform).transpose(0, 1) | |
waveform, sample_rate = apply_preprocessing(waveform, sample_rate) | |
pred = model(waveform.unsqueeze(0).to(device)) | |
pred = torch.sigmoid(pred) | |
return f"{(100-pred[0][0]*100):.2f}% ํ๋ฅ ๋ก fake์ ๋๋ค." | |
# Streamlit UI | |
st.title("DeepFake Detection Demo") | |
st.markdown("whisper-specrnet (using MLAAD, MAILABS, aihub ๊ฐ์ฑ ๋ฐ ๋ฐํ์คํ์ผ ๋์ ๊ณ ๋ ค ์์ฑํฉ์ฑ ๋ฐ์ดํฐ, ์์ฒด ์์ง ๋ฐ ์์ฑํ KoAAD)") | |
st.markdown("github : https://github.com/ldh-Hoon/ko_deepfake-whisper-features") | |
tab1, tab2, tab3 = st.tabs(["YouTube URL", "ํ์ผ ์ ๋ก๋", "์ค์๊ฐ ์ค๋์ค ์ ๋ ฅ"]) | |
with tab1: | |
youtube_url = st.text_input("YouTube URL") | |
st.markdown("""example | |
>fake: | |
https://youtu.be/ha3gfD7S0_E | |
https://youtu.be/5lmJ0Rhr-ec | |
https://youtu.be/q6ra0KDgVbg | |
https://youtu.be/hfmm1Oo6SSY?feature=shared | |
https://youtu.be/8QcbRM0Zq_c?feature=shared | |
>real: | |
https://youtu.be/54y1sYLZjqs | |
https://youtu.be/7qT0Stb3QNY | |
""") | |
if st.button("RUN URL"): | |
result = pred_from_url(youtube_url) | |
st.text_area("๊ฒฐ๊ณผ", value=result, height=150) | |
with tab2: | |
file = st.file_uploader("์ค๋์ค ํ์ผ ์ ๋ก๋", type=['mp3', 'wav']) | |
if file is not None and st.button("RUN ํ์ผ"): | |
# ์์ ํ์ผ ์ ์ฅ | |
with open(file.name, "wb") as f: | |
f.write(file.getbuffer()) | |
result = pred_from_file(file.name) | |
st.text_area("๊ฒฐ๊ณผ", value=result, height=150) | |
os.remove(file.name) # ์์ ํ์ผ ์ญ์ | |
with tab3: | |
file = 'temp.wav' | |
wav_audio_data = st_audiorec() | |
if wav_audio_data is not None: | |
with st.spinner('๋ น์๋ ์์ฑ์ ์ ์ฅ์ค...'): | |
with open(file, 'wb') as f: | |
f.write(wav_audio_data) | |
result = pred_from_file(file) | |
st.text_area("๊ฒฐ๊ณผ", value=result, height=150) | |
os.remove(file) # ์์ ํ์ผ ์ญ์ |