# streamlit_demo/app.py
import os

import streamlit as st
import torch
import torchaudio
import yaml
import sounddevice as sd  # needed by pred_from_realtime_audio below
from pytube import YouTube
from st_audiorec import st_audiorec

from src.models import models
import subprocess

# Install sox at startup; it is presumably required by the
# torchaudio.sox_effects calls in the preprocessing functions below.
def run_command(command):
    """Run a shell command and print its output."""
    process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = process.communicate()
    print(stdout, stderr)

run_command("apt-get update")
run_command("apt-get install sox libsox-dev -y")
# Preprocessing functions
SAMPLING_RATE = 16_000


def apply_preprocessing(waveform, sample_rate):
    """Resample to SAMPLING_RATE, downmix to mono, trim silence, and pad/cut to a fixed length."""
    if sample_rate != SAMPLING_RATE and SAMPLING_RATE != -1:
        waveform, sample_rate = resample_wave(waveform, sample_rate, SAMPLING_RATE)
    # Stereo to mono: keep the first channel
    if waveform.dim() > 1 and waveform.shape[0] > 1:
        waveform = waveform[:1, ...]
    waveform, sample_rate = apply_trim(waveform, sample_rate)
    waveform = apply_pad(waveform, 480_000)
    return waveform, sample_rate
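# Net effect (a sketch of the pipeline above): e.g. a 44.1 kHz stereo clip comes
# out as a mono 16 kHz tensor of exactly 480_000 samples (30 s), silence-trimmed
# and repeat-padded.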
def resample_wave(waveform, sample_rate, target_sample_rate):
    waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
        waveform, sample_rate, [["rate", f"{target_sample_rate}"]]
    )
    return waveform, sample_rate
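# Note: torchaudio.functional.resample(waveform, sample_rate, target_sample_rate)
# would be a sox-free alternative; the sox "rate" effect is kept here to match
# the original pipeline.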
def apply_trim(waveform, sample_rate):
    # sox "silence" effect: strip leading audio below the 1% amplitude threshold
    # (measured in 0.2 s windows), then (-1) do the same from the end,
    # i.e. trim silence on both sides.
    (
        waveform_trimmed,
        sample_rate_trimmed,
    ) = torchaudio.sox_effects.apply_effects_tensor(
        waveform, sample_rate, [["silence", "1", "0.2", "1%", "-1", "0.2", "1%"]]
    )
    # Keep the trimmed version only if anything is left
    if waveform_trimmed.size()[1] > 0:
        waveform = waveform_trimmed
        sample_rate = sample_rate_trimmed
    return waveform, sample_rate
def apply_pad(waveform, cut):
    """Pad wave by repeating the signal until `cut` length is achieved."""
    waveform = waveform.squeeze(0)
    waveform_len = waveform.shape[0]
    if waveform_len >= cut:
        return waveform[:cut]
    # Too short: tile the signal and truncate to exactly `cut` samples
    num_repeats = int(cut / waveform_len) + 1
    padded_waveform = torch.tile(waveform, (1, num_repeats))[:, :cut][0]
    return padded_waveform
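# Worked example (sketch): a 3-sample wave tiled to fill cut=8 samples:
#   apply_pad(torch.tensor([[1., 2., 3.]]), 8)
#   -> tensor([1., 2., 3., 1., 2., 3., 1., 2.])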
# Model configuration and loading
device = "cuda" if torch.cuda.is_available() else "cpu"

with open('augmentation_ko_whisper_frontend_lcnn_mfcc.yaml', 'r') as f:
    model_config = yaml.safe_load(f)

model_path = model_config["checkpoint"]["path"]
model_name, model_parameters = model_config["model"]["name"], model_config["model"]["parameters"]

model = models.get_model(
    model_name=model_name,
    config=model_parameters,
    device=device,
)
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model = model.to(device)
model.eval()
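# Expected layout of the YAML config (a sketch inferred from the keys read
# above; the values here are placeholders, not the real ones):
#
#   checkpoint:
#     path: path/to/checkpoint.pth
#   model:
#     name: <architecture id understood by models.get_model>
#     parameters: {...}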
# Download a YouTube video's audio track
def download_youtube_audio(youtube_url, output_path="temp"):
    yt = YouTube(youtube_url)
    audio_stream = yt.streams.get_audio_only()
    output_file = audio_stream.download(output_path=output_path)
    title = audio_stream.default_filename
    return output_file, title
# Predict from a YouTube URL
def pred_from_url(youtube_url, segment_length=30):
    global model
    audio_path, title = download_youtube_audio(youtube_url)
    waveform, sample_rate = torchaudio.load(audio_path, normalize=True)
    # Downmix to mono
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    num_samples_per_segment = int(segment_length * sample_rate)
    total_samples = waveform.size(1)
    # e.g. a 95 s clip split into 30 s segments yields 3 full segments;
    # the floor division below drops the 5 s tail
    if total_samples <= num_samples_per_segment:
        num_samples_per_segment = total_samples
        num_segments = 1
    else:
        num_segments = total_samples // num_samples_per_segment
    preds = []
    for i in range(num_segments):
        # Run inference on each segment
        start_sample = i * num_samples_per_segment
        end_sample = start_sample + num_samples_per_segment
        segment = waveform[:, start_sample:end_sample]
        segment, sample_rate = apply_preprocessing(segment, sample_rate)
        with torch.no_grad():
            pred = model(segment.unsqueeze(0).to(device))
        pred = torch.sigmoid(pred)
        preds.append(pred.item())
    # Average the per-segment predictions
    avg_pred = torch.tensor(preds).mean().item()
    os.remove(audio_path)
    return f"{title}\n\n{(100 - avg_pred * 100):.2f}% probability of being fake."
# ํŒŒ์ผ๋กœ๋ถ€ํ„ฐ ์˜ˆ์ธก
def pred_from_file(file_path):
global model
waveform, sample_rate = torchaudio.load(file_path, normalize=True)
waveform, sample_rate = apply_preprocessing(waveform, sample_rate)
pred = model(waveform.unsqueeze(0).to(device))
pred = torch.sigmoid(pred)
return f"{(100-pred[0][0]*100):.2f}% ํ™•๋ฅ ๋กœ fake์ž…๋‹ˆ๋‹ค."
# Predict from microphone input recorded with sounddevice
def pred_from_realtime_audio(duration=5, sample_rate=SAMPLING_RATE):
    global model

    def record_audio(duration, sample_rate):
        print("Recording...")
        recording = sd.rec(int(duration * sample_rate), samplerate=sample_rate, channels=1)
        sd.wait()  # Wait until recording is finished
        print("Recording finished.")
        return recording

    waveform = record_audio(duration, sample_rate)
    waveform = torch.tensor(waveform).transpose(0, 1)
    waveform, sample_rate = apply_preprocessing(waveform, sample_rate)
    with torch.no_grad():
        pred = model(waveform.unsqueeze(0).to(device))
    pred = torch.sigmoid(pred)
    return f"{(100 - pred[0][0].item() * 100):.2f}% probability of being fake."
# Streamlit UI
st.title("DeepFake Detection Demo")
st.markdown("whisper-specrnet (trained on MLAAD, MAILABS, the aihub speech-synthesis dataset that jointly covers emotion and speaking style, and the self-collected and generated KoAAD)")
st.markdown("github : https://github.com/ldh-Hoon/ko_deepfake-whisper-features")

tab1, tab2, tab3 = st.tabs(["YouTube URL", "File upload", "Real-time audio input"])
with tab1:
    youtube_url = st.text_input("YouTube URL")
    st.markdown("""example
>fake:
https://youtu.be/ha3gfD7S0_E
https://youtu.be/5lmJ0Rhr-ec
https://youtu.be/q6ra0KDgVbg
https://youtu.be/hfmm1Oo6SSY?feature=shared
https://youtu.be/8QcbRM0Zq_c?feature=shared

>real:
https://youtu.be/54y1sYLZjqs
https://youtu.be/7qT0Stb3QNY
""")
    if st.button("RUN URL"):
        result = pred_from_url(youtube_url)
        st.text_area("Result", value=result, height=150)
with tab2:
    file = st.file_uploader("Upload an audio file", type=['mp3', 'wav'])
    if file is not None and st.button("RUN file"):
        # Save to a temporary file
        with open(file.name, "wb") as f:
            f.write(file.getbuffer())
        result = pred_from_file(file.name)
        st.text_area("Result", value=result, height=150)
        os.remove(file.name)  # Delete the temporary file
with tab3:
    file = 'temp.wav'
    wav_audio_data = st_audiorec()
    if wav_audio_data is not None:
        with st.spinner('Saving the recorded audio...'):
            with open(file, 'wb') as f:
                f.write(wav_audio_data)
        result = pred_from_file(file)
        st.text_area("Result", value=result, height=150)
        os.remove(file)  # Delete the temporary file