import streamlit as st
import os, glob, pydub, time
from pytube import YouTube
import torch, torchaudio
import yaml
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchaudio.transforms as T
from src.models import models
from st_audiorec import st_audiorec
from pathlib import Path
import numpy as np
import subprocess
# Run the command
command = "apt-get update"
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# Print the command output
stdout, stderr = process.communicate()
print(stdout, stderr)

command = "apt-get install sox libsox-dev -y"
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# Print the command output
stdout, stderr = process.communicate()
print(stdout, stderr)
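# sox and libsox are installed above because the preprocessing functions below
# rely on torchaudio.sox_effects.apply_effects_tensor for resampling and trimming.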
from twilio.base.exceptions import TwilioRestException
from twilio.rest import Client
import queue
def get_ice_servers():
    """Use Twilio's TURN server because Streamlit Community Cloud has changed
    its infrastructure and WebRTC connection cannot be established without TURN server now. # noqa: E501
    We considered Open Relay Project (https://www.metered.ca/tools/openrelay/) too,
    but it is not stable and hardly works as some people reported like https://github.com/aiortc/aiortc/issues/832#issuecomment-1482420656 # noqa: E501
    See https://github.com/whitphx/streamlit-webrtc/issues/1213
    """
    # Ref: https://www.twilio.com/docs/stun-turn/api
    try:
        account_sid = os.environ["TWILIO_ACCOUNT_SID"]
        auth_token = os.environ["TWILIO_AUTH_TOKEN"]
    except KeyError:
        return [{"urls": ["stun:stun.l.google.com:19302"]}]

    client = Client(account_sid, auth_token)

    try:
        token = client.tokens.create()
    except TwilioRestException as e:
        st.warning(
            f"Error occurred while accessing Twilio API. Fallback to a free STUN server from Google. ({e})"  # noqa: E501
        )
        return [{"urls": ["stun:stun.l.google.com:19302"]}]

    return token.ice_servers
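# The ICE server list returned here is passed to webrtc_streamer via
# rtc_configuration in the real-time tab below.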
from streamlit_webrtc import WebRtcMode, webrtc_streamer
from pydub import AudioSegment
from pyannote.audio import Pipeline
import soundfile as sf
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Replace with your actual Hugging Face API token
huggingface_token = os.environ["key"]
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",
                                    use_auth_token=huggingface_token).to(device)

output_directory = '/MP3_Split'
def split_by_speaker(file_path, output_dir):
    """Split an MP3 into two tracks by alternating diarized speaker turns."""
    # Load the MP3 file
    audio = AudioSegment.from_mp3(file_path)

    # Convert audio to wav format (PyAnnote requires wav format)
    wav_path = file_path.replace('.mp3', '.wav')
    audio.export(wav_path, format="wav")

    # Perform speaker diarization
    diarization = pipeline(wav_path)

    # pydub durations are in milliseconds, so each track starts with 5 ms of silence
    audio_0_2_4 = AudioSegment.silent(duration=5)
    audio_1_3_5 = AudioSegment.silent(duration=5)

    # Split the audio based on diarization results
    base_filename = os.path.splitext(os.path.basename(file_path))[0]
    for i, (segment, _, speaker) in enumerate(diarization.itertracks(yield_label=True)):
        # Extract segment
        start_time = segment.start * 1000  # PyAnnote uses seconds, pydub uses milliseconds
        end_time = segment.end * 1000
        audio_segment = audio[start_time:end_time]
        # Append even-indexed turns to one track and odd-indexed turns to the other
        if i % 2 == 0:
            audio_0_2_4 += audio_segment
        else:
            audio_1_3_5 += audio_segment

    # Save each track as a separate MP3 file
    os.makedirs(output_dir, exist_ok=True)
    audio_0_2_4.export(os.path.join(output_dir, "0_speaker.mp3"), format="mp3")
    audio_1_3_5.export(os.path.join(output_dir, "1_speaker.mp3"), format="mp3")
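# Example usage (hypothetical file name): writes 0_speaker.mp3 and 1_speaker.mp3
# into output_directory, one per alternating speaker-turn bucket:
# split_by_speaker("interview.mp3", output_directory)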
def clear_files_in_directory(directory):
    if os.path.exists(directory):
        for filename in os.listdir(directory):
            file_path = os.path.join(directory, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    clear_files_in_directory(file_path)
                    os.rmdir(file_path)  # Remove the subdirectory once it has been emptied
            except Exception as e:
                print(f'Error while deleting {file_path}: {e}')
    else:
        print(f'Directory {directory} does not exist.')
# Preprocessing functions
SAMPLING_RATE = 16_000

def apply_preprocessing(
    waveform,
    sample_rate,
):
    """Resample to SAMPLING_RATE, downmix to mono, trim silence, and pad/cut to 480k samples."""
    if sample_rate != SAMPLING_RATE and SAMPLING_RATE != -1:
        waveform, sample_rate = resample_wave(waveform, sample_rate, SAMPLING_RATE)

    # Stereo to mono
    if waveform.dim() > 1 and waveform.shape[0] > 1:
        waveform = waveform[:1, ...]

    waveform, sample_rate = apply_trim(waveform, sample_rate)
    waveform = apply_pad(waveform, 480_000)
    return waveform, sample_rate
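# For instance, a 30 s stereo clip at 44.1 kHz comes back as a mono,
# 480,000-sample tensor (30 s at 16 kHz) after resampling, trimming, and padding.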
def resample_wave(waveform, sample_rate, target_sample_rate):
    waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
        waveform, sample_rate, [["rate", f"{target_sample_rate}"]]
    )
    return waveform, sample_rate

def apply_trim(waveform, sample_rate):
    # Trim leading/trailing silence with the sox "silence" effect
    (
        waveform_trimmed,
        sample_rate_trimmed,
    ) = torchaudio.sox_effects.apply_effects_tensor(
        waveform, sample_rate, [["silence", "1", "0.2", "1%", "-1", "0.2", "1%"]]
    )
    if waveform_trimmed.size()[1] > 0:
        waveform = waveform_trimmed
        sample_rate = sample_rate_trimmed
    return waveform, sample_rate
def apply_pad(waveform, cut):
    """Pad wave by repeating signal until `cut` length is achieved."""
    waveform = waveform.squeeze(0)
    waveform_len = waveform.shape[0]

    if waveform_len >= cut:
        return waveform[:cut]

    # need to pad
    num_repeats = int(cut / waveform_len) + 1
    padded_waveform = torch.tile(waveform, (1, num_repeats))[:, :cut][0]
    return padded_waveform
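# For example, a 3-sample wave [a, b, c] with cut=8 is tiled to
# [a, b, c, a, b, c, a, b]; waves longer than `cut` are truncated instead.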

# Model setup and loading
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| with open('augmentation_ko_whisper_frontend_lcnn_mfcc.yaml', 'r') as f: | |
| model_config = yaml.safe_load(f) | |
| model_paths = model_config["checkpoint"]["path"] | |
| model_name, model_parameters = model_config["model"]["name"], model_config["model"]["parameters"] | |
| model = models.get_model( | |
| model_name=model_name, | |
| config=model_parameters, | |
| device=device, | |
| ) | |
| model.load_state_dict(torch.load(model_paths, map_location=torch.device('cpu'))) | |
| model = model.to(device) | |
| model.eval() | |
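# The model produces a single logit per clip; torch.sigmoid() in the
# prediction functions below maps it to a "fake" probability.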
# Download a YouTube video and extract its audio
def download_youtube_audio(youtube_url, output_path="temp"):
    yt = YouTube(youtube_url)
    audio_stream = yt.streams.get_audio_only()
    output_file = audio_stream.download(output_path=output_path)
    title = audio_stream.default_filename
    return output_file, title
# Predict from a YouTube URL
def pred_from_url(youtube_url, segment_length=30):
    global model
    audio_path, title = download_youtube_audio(youtube_url)
    print(f"- Running on [{title}]\n\n")

    waveform, sample_rate = torchaudio.load(audio_path, normalize=True)
    # Resample from the file's native rate to the model's rate
    waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=SAMPLING_RATE)
    sample_rate = SAMPLING_RATE
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    num_samples_per_segment = int(segment_length * sample_rate)
    total_samples = waveform.size(1)
    if total_samples <= num_samples_per_segment:
        num_samples_per_segment = total_samples
        num_segments = 1
    else:
        num_segments = total_samples // num_samples_per_segment

    preds = []
    print("Number of audio chunks:", num_segments)
    for i in range(num_segments):
        start_sample = i * num_samples_per_segment
        end_sample = start_sample + num_samples_per_segment
        segment = waveform[:, start_sample:end_sample]
        segment, sample_rate = apply_preprocessing(segment, sample_rate)
        pred = model(segment.unsqueeze(0).to(device))
        pred = torch.sigmoid(pred)
        preds.append(pred.item())

    avg_pred = torch.tensor(preds).mean().item()
    os.remove(audio_path)

    output = "fake" if avg_pred >= 0.5 else "real"
    return f"""Prediction: {output}
fake with {(avg_pred*100):.2f}% probability."""
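# Example call (hypothetical URL):
# print(pred_from_url("https://youtu.be/..."))  # returns the formatted result string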
# Predict from an uploaded file
def pred_from_file(file_path, segment_length=30):
    global model
    clear_files_in_directory(output_directory)
    split_by_speaker(file_path, output_directory)
    output = ""
    for p in list(Path(output_directory).glob("*.mp3")):
        waveform, sample_rate = torchaudio.load(p, normalize=True)
        # Resample from the file's native rate to the model's rate
        waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=SAMPLING_RATE)
        sample_rate = SAMPLING_RATE
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        num_samples_per_segment = int(segment_length * sample_rate)
        total_samples = waveform.size(1)
        if total_samples <= num_samples_per_segment:
            num_samples_per_segment = total_samples
            num_segments = 1
        else:
            num_segments = total_samples // num_samples_per_segment

        preds = []
        print(f"Speaker {p.name}: number of audio chunks: {num_segments}")
        for i in range(num_segments):
            # Run inference on each segment
            start_sample = i * num_samples_per_segment
            end_sample = start_sample + num_samples_per_segment
            segment = waveform[:, start_sample:end_sample]
            segment, sample_rate = apply_preprocessing(segment, sample_rate)
            pred = model(segment.unsqueeze(0).to(device))
            pred = torch.sigmoid(pred)
            preds.append(pred.item())

        avg_pred = torch.tensor(preds).mean().item()
        output += f"Speaker {p.name}: fake with {(avg_pred*100):.2f}% probability.\n\n"
    return output
def pred_from_realtime_audio(data):
    global model
    data = torch.tensor(data, dtype=torch.float32)
    data = data.unsqueeze(0)
    # Browser capture below arrives at 48 kHz; resample to the model's rate
    data = torchaudio.functional.resample(data, orig_freq=48000, new_freq=SAMPLING_RATE)
    # Peak-normalize, then standardize
    data = data / torch.max(torch.abs(data))
    mean = torch.mean(data)
    std = torch.std(data)
    data = (data - mean) / std
    data, sample_rate = apply_preprocessing(data, SAMPLING_RATE)
    pred = model(data.unsqueeze(0).to(device))
    pred = torch.sigmoid(pred)
    return pred.item()
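# Sanity check (hypothetical input): one second of random noise at 48 kHz
# should yield a probability in [0, 1]:
# print(pred_from_realtime_audio(np.random.randn(48000)))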
# Streamlit UI
st.title("DeepFake Detection Demo")
st.markdown("whisper-LCNN (trained on MLAAD, MAILABS, aihub speech-synthesis data covering emotion and speaking style, and self-collected/generated KoAAD)")
st.markdown("github : https://github.com/ldh-Hoon/ko_deepfake-whisper-features")

tab1, tab2, tab3 = st.tabs(["YouTube URL", "File upload", "Real-time audio input"])
example_urls_fake = [
    "https://youtu.be/ha3gfD7S0_E",
    "https://youtu.be/5lmJ0Rhr-ec",
    "https://youtu.be/q6ra0KDgVbg",
    "https://youtu.be/hfmm1Oo6SSY?feature=shared"
]
example_urls_real = [
    "https://youtu.be/54y1sYLZjqs",
    "https://youtu.be/7qT0Stb3QNY",
]

if 'youtube_url' not in st.session_state:
    st.session_state['youtube_url'] = ''
with tab1:
    st.markdown("""example
>fake:
""")
    for url in example_urls_fake:
        if st.button(url, key=url):
            st.session_state.youtube_url = url
    st.markdown(""">real:
""")
    for url in example_urls_real:
        if st.button(url, key=url):
            st.session_state.youtube_url = url

    youtube_url = st.text_input("YouTube URL", value=st.session_state.youtube_url)
    if youtube_url:
        result = pred_from_url(youtube_url)  # Uses pred_from_url defined above
        st.text_area("Result", value=result, height=150)
        st.video(youtube_url)
with tab2:
    file = st.file_uploader("Upload an audio file", type=['mp3', 'wav'])
    if file is not None and st.button("RUN file"):
        # Save to a temporary file
        with open(file.name, "wb") as f:
            f.write(file.getbuffer())
        result = pred_from_file(file.name)
        st.text_area("Result", value=result, height=150)
        os.remove(file.name)  # Delete the temporary file
with tab3:
    p = st.empty()
    preds = []
    fig, [ax_time, ax_freq] = plt.subplots(2, 1, gridspec_kw={"top": 1.5, "bottom": 0.2})

    sound_window_len = 2000  # 2 s window (pydub works in milliseconds)
    sound_window_buffer = None

    webrtc_ctx = webrtc_streamer(
        key="sendonly-audio",
        mode=WebRtcMode.SENDONLY,
        audio_receiver_size=1024,
        rtc_configuration={"iceServers": get_ice_servers()},
        media_stream_constraints={"audio": True},
    )

    while True:
        if webrtc_ctx.audio_receiver:
            try:
                audio_frames = webrtc_ctx.audio_receiver.get_frames(timeout=1)
            except queue.Empty:
                break

            sound_chunk = pydub.AudioSegment.empty()
            for audio_frame in audio_frames:
                sound = pydub.AudioSegment(
                    data=audio_frame.to_ndarray().tobytes(),
                    sample_width=audio_frame.format.bytes,
                    frame_rate=audio_frame.sample_rate,
                    channels=len(audio_frame.layout.channels),
                )
                sound_chunk += sound

            if len(sound_chunk) > 0:
                if sound_window_buffer is None:
                    sound_window_buffer = pydub.AudioSegment.silent(
                        duration=sound_window_len
                    )
                sound_window_buffer += sound_chunk
                if len(sound_window_buffer) > sound_window_len:
                    # Keep only the most recent window
                    sound_window_buffer = sound_window_buffer[-sound_window_len:]

            if sound_window_buffer:
                # Ref: https://own-search-and-study.xyz/2017/10/27/python%E3%82%92%E4%BD%BF%E3%81%A3%E3%81%A6%E9%9F%B3%E5%A3%B0%E3%83%87%E3%83%BC%E3%82%BF%E3%81%8B%E3%82%89%E3%82%B9%E3%83%9A%E3%82%AF%E3%83%88%E3%83%AD%E3%82%B0%E3%83%A9%E3%83%A0%E3%82%92%E4%BD%9C/  # noqa
                sound_window_buffer = sound_window_buffer.set_channels(1)  # Stereo to mono
                sample = np.array(sound_window_buffer.get_array_of_samples())
                preds.append(pred_from_realtime_audio(sample))
                if len(preds) > 100:
                    preds = preds[-100:]
                p.write(f"pred : {np.mean(preds)*100:.2f}%")
        else:
            break