import streamlit as st
import os, glob, pydub, time
from pytube import YouTube
import torch, torchaudio
import yaml
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchaudio.transforms as T
from src.models import models
from st_audiorec import st_audiorec
from pathlib import Path
import numpy as np
import subprocess
# Run the command
command = "apt-get update"
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# Print the command output
stdout, stderr = process.communicate()
print(stdout, stderr)
command = "apt-get install sox libsox-dev -y"
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# Print the command output
stdout, stderr = process.communicate()
print(stdout, stderr)
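# sox and libsox are installed because torchaudio.sox_effects (used below in
# resample_wave and apply_trim) relies on the libsox backend.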
from twilio.base.exceptions import TwilioRestException
from twilio.rest import Client
import queue
def get_ice_servers():
"""Use Twilio's TURN server because Streamlit Community Cloud has changed
its infrastructure and WebRTC connection cannot be established without TURN server now. # noqa: E501
We considered Open Relay Project (https://www.metered.ca/tools/openrelay/) too,
but it is not stable and hardly works as some people reported like https://github.com/aiortc/aiortc/issues/832#issuecomment-1482420656 # noqa: E501
See https://github.com/whitphx/streamlit-webrtc/issues/1213
"""
# Ref: https://www.twilio.com/docs/stun-turn/api
try:
account_sid = os.environ["TWILIO_ACCOUNT_SID"]
auth_token = os.environ["TWILIO_AUTH_TOKEN"]
except KeyError:
return [{"urls": ["stun:stun.l.google.com:19302"]}]
client = Client(account_sid, auth_token)
try:
token = client.tokens.create()
except TwilioRestException as e:
st.warning(
f"Error occurred while accessing Twilio API. Fallback to a free STUN server from Google. ({e})" # noqa: E501
)
return [{"urls": ["stun:stun.l.google.com:19302"]}]
return token.ice_servers
from streamlit_webrtc import WebRtcMode, webrtc_streamer
from pydub import AudioSegment
from pyannote.audio import Pipeline
import soundfile as sf
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Hugging Face API token for pyannote, read from the `key` environment variable
huggingface_token = os.environ["key"]
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",
use_auth_token=huggingface_token).to(device)
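# NOTE: pyannote/speaker-diarization@2.1 is a gated model, so the token read from the
# `key` environment variable must belong to an account that has accepted its terms.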
output_directory = '/MP3_Split'
def split_by_speaker(file_path, output_dir):
# Load the MP3 file
audio = AudioSegment.from_mp3(file_path)
# Convert audio to wav format (PyAnnote requires wav format)
wav_path = file_path.replace('.mp3', '.wav')
audio.export(wav_path, format="wav")
# Perform speaker diarization
diarization = pipeline(wav_path)
    # Two output tracks, each seeded with 5 ms of silence
    audio_0_2_4 = AudioSegment.silent(duration=5)
    audio_1_3_5 = AudioSegment.silent(duration=5)
# Split the audio based on diarization results
base_filename = os.path.splitext(os.path.basename(file_path))[0]
for i, (segment, _, speaker) in enumerate(diarization.itertracks(yield_label=True)):
# Extract segment
start_time = segment.start * 1000 # PyAnnote uses seconds, pydub uses milliseconds
end_time = segment.end * 1000
audio_segment = audio[start_time:end_time]
        # Append the segment to one of the two accumulated tracks
        # (even-indexed segments to the first, odd-indexed to the second)
        if i % 2 == 0:
            audio_0_2_4 += audio_segment
        else:
            audio_1_3_5 += audio_segment
os.makedirs(output_dir, exist_ok=True)
audio_0_2_4.export(os.path.join(output_dir, f"{0}_speaker.mp3"), format="mp3")
audio_1_3_5.export(os.path.join(output_dir, f"{1}_speaker.mp3"), format="mp3")
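# Rough usage sketch (hypothetical input path): split_by_speaker("dialogue.mp3", output_directory)
# converts the MP3 to WAV, runs diarization, and writes the two accumulated tracks as
# "0_speaker.mp3" and "1_speaker.mp3" in the output directory.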
def clear_files_in_directory(directory):
if os.path.exists(directory):
for filename in os.listdir(directory):
file_path = os.path.join(directory, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
clear_files_in_directory(file_path)
                    os.rmdir(file_path)  # remove the subdirectory once it is empty
            except Exception as e:
                print(f'Error while deleting file {file_path}: {e}')
    else:
        print(f'Directory {directory} does not exist.')
# Preprocessing helpers
SAMPLING_RATE = 16_000
def apply_preprocessing(
waveform,
sample_rate,
):
if sample_rate != SAMPLING_RATE and SAMPLING_RATE != -1:
waveform, sample_rate = resample_wave(waveform, sample_rate, SAMPLING_RATE)
# Stereo to mono
if waveform.dim() > 1 and waveform.shape[0] > 1:
waveform = waveform[:1, ...]
waveform, sample_rate = apply_trim(waveform, sample_rate)
waveform = apply_pad(waveform, 480_000)
return waveform, sample_rate
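# Preprocessing summary: resample to 16 kHz, keep a single channel, trim leading and
# trailing silence with sox, then pad or crop to 480,000 samples (30 s at 16 kHz).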
def resample_wave(waveform, sample_rate, target_sample_rate):
waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
waveform, sample_rate, [["rate", f"{target_sample_rate}"]]
)
return waveform, sample_rate
def apply_trim(waveform, sample_rate):
(
waveform_trimmed,
sample_rate_trimmed,
) = torchaudio.sox_effects.apply_effects_tensor(waveform, sample_rate, [["silence", "1", "0.2", "1%", "-1", "0.2", "1%"]])
if waveform_trimmed.size()[1] > 0:
waveform = waveform_trimmed
sample_rate = sample_rate_trimmed
return waveform, sample_rate
def apply_pad(waveform, cut):
"""Pad wave by repeating signal until `cut` length is achieved."""
waveform = waveform.squeeze(0)
waveform_len = waveform.shape[0]
if waveform_len >= cut:
return waveform[:cut]
# need to pad
num_repeats = int(cut / waveform_len) + 1
padded_waveform = torch.tile(waveform, (1, num_repeats))[:, :cut][0]
return padded_waveform
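# Example: a 10 s clip at 16 kHz (160,000 samples) is tiled 4 times to 640,000 samples
# and then cut back to 480,000, so every model input is exactly 30 s long.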
#
#
#
# λͺ¨λΈ μ„€μ • 및 λ‘œλ”©
device = "cuda" if torch.cuda.is_available() else "cpu"
with open('augmentation_ko_whisper_frontend_lcnn_mfcc.yaml', 'r') as f:
model_config = yaml.safe_load(f)
model_paths = model_config["checkpoint"]["path"]
model_name, model_parameters = model_config["model"]["name"], model_config["model"]["parameters"]
model = models.get_model(
model_name=model_name,
config=model_parameters,
device=device,
)
model.load_state_dict(torch.load(model_paths, map_location=torch.device('cpu')))
model = model.to(device)
model.eval()
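# The YAML config supplies both the checkpoint path (checkpoint.path) and the architecture
# (model.name / model.parameters); map_location='cpu' lets the checkpoint load on CPU-only
# machines before the model is moved to `device`.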
# Download a YouTube video and extract its audio
def download_youtube_audio(youtube_url, output_path="temp"):
yt = YouTube(youtube_url)
audio_stream = yt.streams.get_audio_only()
output_file = audio_stream.download(output_path=output_path)
title = audio_stream.default_filename
return output_file, title
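# Note: get_audio_only() downloads pytube's audio-only stream (typically an MP4/WebM
# container); torchaudio.load in pred_from_url is assumed to have a backend that can decode it.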
# Prediction from a YouTube URL
def pred_from_url(youtube_url, segment_length=30):
global model
audio_path, title = download_youtube_audio(youtube_url)
print(f"- [{title}]에 λŒ€ν•΄ μ‹€ν–‰\n\n")
waveform, sample_rate = torchaudio.load(audio_path, normalize=True)
    # Resample from the file's native rate to the model's 16 kHz input rate
    waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=SAMPLING_RATE)
    sample_rate = SAMPLING_RATE
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    num_samples_per_segment = int(segment_length * sample_rate)
total_samples = waveform.size(1)
if total_samples <= num_samples_per_segment:
num_samples_per_segment = total_samples
num_segments = 1
else:
num_segments = total_samples // num_samples_per_segment
preds = []
print("μ˜€λ””μ˜€ chunk λΆ„ν•  수 :", num_segments)
for i in range(num_segments):
start_sample = i * num_samples_per_segment
end_sample = start_sample + num_samples_per_segment
segment = waveform[:, start_sample:end_sample]
segment, sample_rate = apply_preprocessing(segment, sample_rate)
pred = model(segment.unsqueeze(0).to(device))
pred = torch.sigmoid(pred)
preds.append(pred.item())
avg_pred = torch.tensor(preds).mean().item()
os.remove(audio_path)
output = ""
if int(avg_pred+0.5):
output = "fake"
else:
output = "real"
return f"""예츑:{output}
{(avg_pred*100):.2f}% ν™•λ₯ λ‘œ fakeμž…λ‹ˆλ‹€."""
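# The reported percentage is the mean sigmoid score over the 30 s chunks; scores of 0.5
# and above are labelled "fake", everything below "real".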
# νŒŒμΌλ‘œλΆ€ν„° 예츑
def pred_from_file(file_path, segment_length=30):
global model
clear_files_in_directory(output_directory)
split_by_speaker(file_path, output_directory)
output = ""
for p in list(Path(output_directory).glob("*.mp3")):
waveform, sample_rate = torchaudio.load(p, normalize=True)
        # Resample from the file's native rate to the model's 16 kHz input rate
        waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=SAMPLING_RATE)
        sample_rate = SAMPLING_RATE
if waveform.size(0) > 1:
waveform = waveform.mean(dim=0, keepdim=True)
num_samples_per_segment = int(segment_length * sample_rate)
total_samples = waveform.size(1)
if total_samples <= num_samples_per_segment:
num_samples_per_segment = total_samples
num_segments = 1
else:
num_segments = total_samples // num_samples_per_segment
preds = []
print(f"ν™”μž {p.name}의 μ˜€λ””μ˜€ chunk λΆ„ν•  수 : {num_segments}")
for i in range(num_segments):
            # Run inference on each segment
start_sample = i * num_samples_per_segment
end_sample = start_sample + num_samples_per_segment
segment = waveform[:, start_sample:end_sample]
segment, sample_rate = apply_preprocessing(segment, sample_rate)
pred = model(segment.unsqueeze(0).to(device))
pred = torch.sigmoid(pred)
preds.append(pred.item())
avg_pred = torch.tensor(preds).mean().item()
output += f"ν™”μž {p.name} : {(avg_pred*100):.2f}% ν™•λ₯ λ‘œ fakeμž…λ‹ˆλ‹€.\n\n"
return output
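# pred_from_file clears the split directory, diarizes the upload into per-speaker MP3s,
# and reports one averaged fake probability per speaker file.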
def pred_from_realtime_audio(data):
global model
data = torch.tensor(data, dtype=torch.float32)
data = data.unsqueeze(0)
    # The WebRTC stream delivers 48 kHz samples; resample to the model's 16 kHz rate
    data = torchaudio.functional.resample(data, orig_freq=48000, new_freq=SAMPLING_RATE)
    # Peak-normalise, guarding against an all-silent buffer
    data = data / (torch.max(torch.abs(data)) + 1e-9)
mean = torch.mean(data)
std = torch.std(data)
data = (data - mean) / std
data, sample_rate = apply_preprocessing(data, SAMPLING_RATE)
    pred = model(data.unsqueeze(0).to(device))
pred = torch.sigmoid(pred)
return pred.item()
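# pred_from_realtime_audio returns a single sigmoid score in [0, 1] for the current audio
# window, with higher values meaning "more likely fake" (same convention as above).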
# Streamlit UI
st.title("DeepFake Detection Demo")
st.markdown("whisper-LCNN (using MLAAD, MAILABS, aihub 감성 및 λ°œν™”μŠ€νƒ€μΌ λ™μ‹œ κ³ λ € μŒμ„±ν•©μ„± 데이터, 자체 μˆ˜μ§‘ 및 μƒμ„±ν•œ KoAAD)")
st.markdown("github : https://github.com/ldh-Hoon/ko_deepfake-whisper-features")
tab1, tab2, tab3 = st.tabs(["YouTube URL", "File upload", "Real-time audio input"])
example_urls_fake = [
"https://youtu.be/ha3gfD7S0_E",
"https://youtu.be/5lmJ0Rhr-ec",
"https://youtu.be/q6ra0KDgVbg",
"https://youtu.be/hfmm1Oo6SSY?feature=shared"
]
example_urls_real = [
"https://youtu.be/54y1sYLZjqs",
"https://youtu.be/7qT0Stb3QNY",
]
if 'youtube_url' not in st.session_state:
st.session_state['youtube_url'] = ''
with tab1:
st.markdown("""example
>fake:
""")
for url in example_urls_fake:
if st.button(url, key=url):
st.session_state.youtube_url = url
st.markdown(""">real:
""")
for url in example_urls_real:
if st.button(url, key=url):
st.session_state.youtube_url = url
youtube_url = st.text_input("YouTube URL", value=st.session_state.youtube_url)
if youtube_url:
        result = pred_from_url(youtube_url)
        st.text_area("Result", value=result, height=150)
st.video(youtube_url)
with tab2:
    file = st.file_uploader("Upload an audio file", type=['mp3', 'wav'])
    if file is not None and st.button("RUN file"):
        # Save a temporary copy of the upload
        with open(file.name, "wb") as f:
            f.write(file.getbuffer())
        result = pred_from_file(file.name)
        st.text_area("Result", value=result, height=150)
        os.remove(file.name)  # delete the temporary copy
with tab3:
p = st.empty()
preds = []
fig, [ax_time, ax_freq] = plt.subplots(2, 1, gridspec_kw={"top": 1.5, "bottom": 0.2})
    sound_window_len = 2000  # 2 s window (pydub lengths are in milliseconds)
sound_window_buffer = None
webrtc_ctx = webrtc_streamer(
key="sendonly-audio",
mode=WebRtcMode.SENDONLY,
audio_receiver_size=1024,
rtc_configuration={"iceServers": get_ice_servers()},
media_stream_constraints={"audio": True},
)
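    # Poll audio frames from the browser, keep a rolling ~2 s pydub buffer, and score each
    # window with pred_from_realtime_audio; the running mean of the last 100 scores is shown
    # in the st.empty() placeholder above.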
while True:
if webrtc_ctx.audio_receiver:
try:
audio_frames = webrtc_ctx.audio_receiver.get_frames(timeout=1)
except queue.Empty:
break
sound_chunk = pydub.AudioSegment.empty()
for audio_frame in audio_frames:
sound = pydub.AudioSegment(
data=audio_frame.to_ndarray().tobytes(),
sample_width=audio_frame.format.bytes,
frame_rate=audio_frame.sample_rate,
channels=len(audio_frame.layout.channels),
)
sound_chunk += sound
if len(sound_chunk) > 0:
if sound_window_buffer is None:
sound_window_buffer = pydub.AudioSegment.silent(
duration=sound_window_len
)
sound_window_buffer += sound_chunk
if len(sound_window_buffer) > sound_window_len:
sound_window_buffer = sound_window_buffer[-sound_window_len:]
if sound_window_buffer:
# Ref: https://own-search-and-study.xyz/2017/10/27/python%E3%82%92%E4%BD%BF%E3%81%A3%E3%81%A6%E9%9F%B3%E5%A3%B0%E3%83%87%E3%83%BC%E3%82%BF%E3%81%8B%E3%82%89%E3%82%B9%E3%83%9A%E3%82%AF%E3%83%88%E3%83%AD%E3%82%B0%E3%83%A9%E3%83%A0%E3%82%92%E4%BD%9C/ # noqa
sound_window_buffer = sound_window_buffer.set_channels(1) # Stereo to mono
sample = np.array(sound_window_buffer.get_array_of_samples())
preds.append(pred_from_realtime_audio(sample))
if len(preds) > 100:
preds = preds[-100:]
p.write(f"pred : {np.mean(preds)*100:.2f}%")
else:
break