import os
import subprocess

import streamlit as st
import torch
import torchaudio
import yaml
from pytube import YouTube
from st_audiorec import st_audiorec

from src.models import models

import sounddevice as sd  # used by pred_from_realtime_audio below (was referenced as `sd` but never imported)


def run_command(command):
    """Run a shell command and print its stdout/stderr."""
    process = subprocess.Popen(
        command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    stdout, stderr = process.communicate()
    print(stdout, stderr)


# Install sox, which torchaudio.sox_effects needs at runtime.
run_command("apt-get update")
run_command("apt-get install sox libsox-dev -y")


# Preprocessing functions
SAMPLING_RATE = 16_000


def apply_preprocessing(waveform, sample_rate):
    # Resample to the model's expected rate (SAMPLING_RATE == -1 disables this).
    if sample_rate != SAMPLING_RATE and SAMPLING_RATE != -1:
        waveform, sample_rate = resample_wave(waveform, sample_rate, SAMPLING_RATE)
    # Stereo to mono: keep the first channel.
    if waveform.dim() > 1 and waveform.shape[0] > 1:
        waveform = waveform[:1, ...]
    waveform, sample_rate = apply_trim(waveform, sample_rate)
    waveform = apply_pad(waveform, 480_000)  # 30 s at 16 kHz
    return waveform, sample_rate


def resample_wave(waveform, sample_rate, target_sample_rate):
    waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
        waveform, sample_rate, [["rate", f"{target_sample_rate}"]]
    )
    return waveform, sample_rate


def apply_trim(waveform, sample_rate):
    # Strip leading/trailing silence with the sox "silence" effect.
    waveform_trimmed, sample_rate_trimmed = torchaudio.sox_effects.apply_effects_tensor(
        waveform, sample_rate, [["silence", "1", "0.2", "1%", "-1", "0.2", "1%"]]
    )
    # Keep the trimmed version only if something is left.
    if waveform_trimmed.size()[1] > 0:
        waveform = waveform_trimmed
        sample_rate = sample_rate_trimmed
    return waveform, sample_rate


def apply_pad(waveform, cut):
    """Pad wave by repeating signal until `cut` length is achieved."""
    waveform = waveform.squeeze(0)
    waveform_len = waveform.shape[0]
    if waveform_len >= cut:
        return waveform[:cut]
    # Repeat the signal enough times, then cut to the target length.
    num_repeats = int(cut / waveform_len) + 1
    padded_waveform = torch.tile(waveform, (1, num_repeats))[:, :cut][0]
    return padded_waveform


# Model configuration and loading
device = "cuda" if torch.cuda.is_available() else "cpu"

with open("augmentation_ko_whisper_frontend_lcnn_mfcc.yaml", "r") as f:
    model_config = yaml.safe_load(f)
model_paths = model_config["checkpoint"]["path"]
model_name, model_parameters = (
    model_config["model"]["name"],
    model_config["model"]["parameters"],
)

model = models.get_model(
    model_name=model_name,
    config=model_parameters,
    device=device,
)
model.load_state_dict(torch.load(model_paths, map_location=torch.device("cpu")))
model = model.to(device)
model.eval()


# Download a YouTube video's audio track and return its path and title.
def download_youtube_audio(youtube_url, output_path="temp"):
    yt = YouTube(youtube_url)
    audio_stream = yt.streams.get_audio_only()
    output_file = audio_stream.download(output_path=output_path)
    title = audio_stream.default_filename
    return output_file, title


# Predict from a YouTube URL: split the audio into fixed-length segments,
# score each one, and average the results.
def pred_from_url(youtube_url, segment_length=30):
    global model
    audio_path, title = download_youtube_audio(youtube_url)
    waveform, sample_rate = torchaudio.load(audio_path, normalize=True)
    # Downmix to mono.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    num_samples_per_segment = int(segment_length * sample_rate)
    total_samples = waveform.size(1)
    if total_samples <= num_samples_per_segment:
        num_samples_per_segment = total_samples
        num_segments = 1
    else:
        num_segments = total_samples // num_samples_per_segment

    preds = []
    print(num_segments)
    for i in range(num_segments):
        # Run inference on each segment. Do not overwrite `sample_rate` here:
        # every slice of `waveform` still has the original rate and must be
        # resampled (the original code reassigned it after the first pass).
        start_sample = i * num_samples_per_segment
        end_sample = start_sample + num_samples_per_segment
        segment = waveform[:, start_sample:end_sample]
        segment, _ = apply_preprocessing(segment, sample_rate)
        pred = model(segment.unsqueeze(0).to(device))
        pred = torch.sigmoid(pred)
        preds.append(pred.item())
    # Average the per-segment predictions.
    avg_pred = torch.tensor(preds).mean().item()
    os.remove(audio_path)
    return f"{title}\n\n{(100 - avg_pred * 100):.2f}% 확률로 fake입니다."
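# A minimal usage sketch for scoring a URL outside the Streamlit UI; the URL
# is taken from the demo's own example list further below:
#   print(pred_from_url("https://youtu.be/54y1sYLZjqs"))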
# Predict from a local audio file.
def pred_from_file(file_path):
    global model
    waveform, sample_rate = torchaudio.load(file_path, normalize=True)
    waveform, sample_rate = apply_preprocessing(waveform, sample_rate)
    pred = model(waveform.unsqueeze(0).to(device))
    pred = torch.sigmoid(pred)
    return f"{(100 - pred[0][0].item() * 100):.2f}% 확률로 fake입니다."


# Predict from a microphone recording via sounddevice. Not wired into the
# Streamlit UI below, which records in the browser with st_audiorec instead.
def pred_from_realtime_audio(duration=5, sample_rate=SAMPLING_RATE):
    global model

    def record_audio(duration, sample_rate):
        print("Recording...")
        recording = sd.rec(
            int(duration * sample_rate), samplerate=sample_rate, channels=1
        )
        sd.wait()  # Wait until recording is finished
        print("Recording finished.")
        return recording

    waveform = record_audio(duration, sample_rate)
    waveform = torch.tensor(waveform).transpose(0, 1)
    waveform, sample_rate = apply_preprocessing(waveform, sample_rate)
    pred = model(waveform.unsqueeze(0).to(device))
    pred = torch.sigmoid(pred)
    return f"{(100 - pred[0][0].item() * 100):.2f}% 확률로 fake입니다."


# Streamlit UI
st.title("DeepFake Detection Demo")
st.markdown(
    "whisper-specrnet (using MLAAD, MAILABS, aihub 감성 및 발화스타일 동시 고려 "
    "음성합성 데이터, 자체 수집 및 생성한 KoAAD)"
)
st.markdown("github : https://github.com/ldh-Hoon/ko_deepfake-whisper-features")

tab1, tab2, tab3 = st.tabs(["YouTube URL", "파일 업로드", "실시간 오디오 입력"])

with tab1:
    youtube_url = st.text_input("YouTube URL")
    st.markdown(
        """example

>fake:
https://youtu.be/ha3gfD7S0_E
https://youtu.be/5lmJ0Rhr-ec
https://youtu.be/q6ra0KDgVbg
https://youtu.be/hfmm1Oo6SSY?feature=shared
https://youtu.be/8QcbRM0Zq_c?feature=shared

>real:
https://youtu.be/54y1sYLZjqs
https://youtu.be/7qT0Stb3QNY
"""
    )
    if st.button("RUN URL"):
        result = pred_from_url(youtube_url)
        st.text_area("결과", value=result, height=150)

with tab2:
    file = st.file_uploader("오디오 파일 업로드", type=["mp3", "wav"])
    if file is not None and st.button("RUN 파일"):
        # Save the upload to a temporary file so torchaudio can read it.
        with open(file.name, "wb") as f:
            f.write(file.getbuffer())
        result = pred_from_file(file.name)
        st.text_area("결과", value=result, height=150)
        os.remove(file.name)  # Delete the temporary file.

with tab3:
    file = "temp.wav"
    wav_audio_data = st_audiorec()
    if wav_audio_data is not None:
        with st.spinner("녹음된 음성을 저장중..."):
            # Save the in-browser recording, score it, then clean up.
            with open(file, "wb") as f:
                f.write(wav_audio_data)
            result = pred_from_file(file)
            st.text_area("결과", value=result, height=150)
            os.remove(file)  # Delete the temporary file.
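# The YAML config loaded above is expected to look roughly like the following.
# This is a minimal sketch: the keys mirror the lookups in this script
# (checkpoint.path, model.name, model.parameters), but the values shown here
# are illustrative assumptions, not the repository's actual config.
#
#   checkpoint:
#     path: weights/ko_whisper_frontend_lcnn_mfcc.pt
#   model:
#     name: whisper_frontend_lcnn_mfcc
#     parameters:
#       input_channels: 1
#
# Launch the demo with `streamlit run app.py` (assuming this script is saved
# as app.py).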