# Spaces:
# Sleeping
# Sleeping
import os  # Import os for file operations
import tempfile

import numpy as np
import plotly.graph_objs as go
import torch
import torchaudio
from flask import Flask, request, render_template, redirect, url_for

from audio_dataset import pad_audio  # Assuming you have a function to pad audio
from model import BoundaryDetectionModel  # Assuming your model is defined here
app = Flask(__name__)

# Inference device: CUDA when torch reports it is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _load_model(checkpoint_path, target_device):
    """Build BoundaryDetectionModel on *target_device*, restore its weights, set eval mode.

    The checkpoint is expected to be a mapping with a "model_state_dict" entry.
    """
    net = BoundaryDetectionModel().to(target_device)
    checkpoint = torch.load(checkpoint_path, map_location=target_device)
    net.load_state_dict(checkpoint["model_state_dict"])
    net.eval()
    return net


# Loaded once at import time so every request reuses the same weights.
model = _load_model("checkpoint_epoch_21_eer_0.24.pth", device)
def preprocess_audio(audio_path, sample_rate=16000, target_length=8):
    """Load an audio file and prepare it as model input.

    The file is decoded with ``torchaudio.load``, resampled to *sample_rate*
    if its native rate differs, passed through ``pad_audio`` together with
    *sample_rate* and *target_length*, and moved to the module-level ``device``.

    NOTE(review): *target_length* is presumably a duration in seconds, as
    interpreted by ``pad_audio`` — confirm against audio_dataset.
    """
    signal, native_rate = torchaudio.load(audio_path)
    if native_rate != sample_rate:
        resampler = torchaudio.transforms.Resample(native_rate, sample_rate)
        signal = resampler(signal)
    padded = pad_audio(signal, sample_rate, target_length)
    return padded.to(device)
def infer_single_audio(audio_tensor):
    """Score one preprocessed clip with the global ``model`` and threshold it.

    Returns:
        (scores, labels): ``scores`` is the model output with its last
        dimension squeezed, as a NumPy array on the CPU. ``labels`` has the
        same shape and holds 1 where a score exceeds 0.5 and 0 elsewhere.
    """
    with torch.no_grad():
        raw_scores = model(audio_tensor)
        scores = raw_scores.squeeze(-1).cpu().numpy()
    is_above_threshold = scores > 0.5
    labels = is_above_threshold.astype(int)  # 1 = fake frame, 0 = real frame
    return scores, labels
# The route registration was missing: without it the view is unreachable and
# ``url_for('index')`` (used by return_to_index) raises BuildError.
@app.route('/')
def index():
    """Render the upload page (index.html), which also displays results."""
    return render_template('index.html')
def _load_uploaded_audio(file):
    """Save an uploaded file to a private temp file, preprocess it, then delete it.

    Each request gets its own uniquely named file, so concurrent uploads cannot
    overwrite each other. The file is removed as soon as it has been decoded
    instead of lingering until the user visits another page.
    """
    # mkstemp creates the file atomically with a unique name. Close our handle
    # so ``file.save`` can reopen the path (required on Windows). The ".wav"
    # suffix matches the name the upload was always saved under before.
    fd, tmp_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        file.save(tmp_path)
        return preprocess_audio(tmp_path)
    finally:
        try:
            os.remove(tmp_path)
        except OSError as e:
            print(f"Error deleting temporary file {tmp_path}: {e}")


# NOTE(review): the route decorator was missing from this file. The path
# '/predict' is an assumption and must match the upload form's action in
# index.html (templates using url_for('predict') work regardless).
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded audio clip as real or fake and render result.html.

    Expects a multipart upload in the form field ``file``.

    Returns:
        The rendered result page. Returns ``("No file uploaded", 400)`` or
        ``("No selected file", 400)`` when the upload is missing or empty.
    """
    if 'file' not in request.files:
        return "No file uploaded", 400
    file = request.files['file']
    if file.filename == '':
        return "No selected file", 400

    # Preprocess audio and perform inference
    audio_tensor = _load_uploaded_audio(file)
    output, prediction = infer_single_audio(audio_tensor)

    # Flatten so per-frame labels can be counted/iterated regardless of any
    # leading dimensions in the model output.
    prediction_flat = prediction.flatten()

    total_frames = len(prediction_flat)
    fake_frame_count = int(np.sum(prediction_flat))
    # Guard against an empty prediction, which would otherwise divide by zero.
    fake_percentage = (
        round((fake_frame_count / total_frames) * 100, 4) if total_frames else 0.0
    )
    # A clip is labelled Fake once at least 5 frames are predicted fake.
    result_type = 'Fake' if fake_frame_count >= 5 else 'Real'

    if result_type == 'Real':
        fake_frame_intervals = "No Frame"
    else:
        # Start/end times (seconds) of each run of fake frames; 20 ms per frame.
        fake_frame_intervals = get_fake_frame_intervals(prediction_flat, frame_duration=20)
    # Debug print to check intervals
    print("Fake Frame Intervals:", fake_frame_intervals)

    plot_html = plot_fake_frames_waveform(
        output, prediction_flat, audio_tensor.cpu().numpy(), fake_frame_intervals
    )
    return render_template('result.html',
                           fake_percentage=fake_percentage,
                           result_type=result_type,
                           fake_frame_count=fake_frame_count,
                           total_frames=total_frames,
                           fake_frame_intervals=fake_frame_intervals,
                           plot_html=plot_html)
# NOTE(review): the route decorator was missing from this file. The path is an
# assumption and must match how result.html links here (url_for-based links
# work regardless). GET and POST are both accepted so a link or a form can use it.
@app.route('/return_to_index', methods=['GET', 'POST'])
def return_to_index():
    """Remove the legacy temporary upload file, if present, and go back to index."""
    try:
        os.remove("temp_audio.wav")
    except FileNotFoundError:
        # Nothing to clean up. The file may already be gone, or uploads may
        # no longer be saved under this fixed name. This is not an error.
        pass
    except OSError as e:
        print(f"Error deleting temporary files: {e}")
    return redirect(url_for('index'))
def get_fake_frame_intervals(prediction, frame_duration=20):
    """Return (start_s, end_s) pairs for every run of consecutive fake frames.

    Args:
        prediction: sequence of per-frame labels; a frame is fake when its
            label equals 1.
        frame_duration: duration of one frame in milliseconds.

    Returns:
        A list of ``(start, end)`` tuples in seconds, rounded to 4 decimals.
        ``end`` is the start time of the first frame after the run, or the
        total length of the sequence if the run reaches the last frame.
    """
    seconds_per_frame = frame_duration / 1000
    labels = list(prediction)
    frame_count = len(labels)
    intervals = []
    pos = 0
    while pos < frame_count:
        if labels[pos] != 1:
            pos += 1
            continue
        run_start = pos
        while pos < frame_count and labels[pos] == 1:
            pos += 1
        intervals.append(
            (round(run_start * seconds_per_frame, 4), round(pos * seconds_per_frame, 4))
        )
    return intervals
def _frame_runs(labels):
    """Yield (first_frame, end_frame, is_fake) for each maximal run of equal labels."""
    run_start = 0
    for i in range(1, len(labels) + 1):
        if i == len(labels) or (labels[i] == 1) != (labels[run_start] == 1):
            yield run_start, i, bool(labels[run_start] == 1)
            run_start = i


def plot_fake_frames_waveform(output, prediction_flat, waveform, fake_frame_intervals,
                              frame_duration=20, sample_rate=16000):
    """Plot channel 0 of *waveform*, colouring fake frames red and real frames green.

    Args:
        output: raw model scores; not used by the plot (kept for compatibility).
        prediction_flat: 1-D per-frame labels (1 = fake, otherwise real).
        waveform: array indexed as ``waveform[channel][sample]``; only channel 0
            is drawn. ``waveform.shape[1]`` is taken as the sample count.
        fake_frame_intervals: not used by the plot (kept for compatibility).
        frame_duration: length of one prediction frame in milliseconds.
        sample_rate: samples per second of *waveform*.

    Returns:
        A Plotly HTML fragment (``full_html=False``) for embedding in a template.
    """
    num_samples = waveform.shape[1]
    actual_duration = num_samples / sample_rate
    # Sample k occurs at k / sample_rate seconds. np.linspace(0, duration, n)
    # would stretch the spacing to duration / (n - 1) and skew the axis.
    time = np.arange(num_samples) / sample_rate
    frame_length = int(sample_rate * frame_duration / 1000)  # samples per frame
    channel = waveform[0]

    # One trace per run of same-label frames rather than one per frame. This
    # keeps the figure small, and contiguous segments are drawn as one line.
    traces = []
    for first_frame, end_frame, is_fake in _frame_runs(prediction_flat):
        start = first_frame * frame_length
        if start >= num_samples:
            break  # remaining frames lie beyond the end of the audio
        end = min(end_frame * frame_length, num_samples)
        color = 'rgba(255,0,0,0.8)' if is_fake else 'rgba(0,128,0,0.5)'
        traces.append(go.Scatter(
            x=time[start:end],
            y=channel[start:end],
            mode='lines',
            line=dict(color=color),
            showlegend=False,
        ))

    layout = go.Layout(
        title="Audio Waveform with Fake Frames Highlighted",
        # Show the whole clip so every fake and real segment is visible.
        xaxis=dict(title="Time (seconds)", range=[0, actual_duration]),
        yaxis=dict(title="Amplitude"),
        autosize=True,
        template="plotly_white",
    )
    fig = go.Figure(data=traces, layout=layout)
    return fig.to_html(full_html=False)
if __name__ == "__main__":
    # Start Flask's built-in development server with its default host and port.
    app.run()