Spaces:
Sleeping
Sleeping
File size: 6,358 Bytes
384e020 0474f44 384e020 0474f44 384e020 0474f44 384e020 0474f44 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import os  # Import os for file operations
import tempfile  # Unique per-request temporary files for uploaded audio

import numpy as np
import plotly.graph_objs as go
import torch
import torchaudio
from flask import Flask, request, render_template, redirect, url_for

from model import BoundaryDetectionModel  # Assuming your model is defined here
from audio_dataset import pad_audio  # Assuming you have a function to pad audio
app = Flask(__name__)

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the detector once at import time, move it to `device`, restore the
# trained weights stored under the checkpoint's "model_state_dict" key, and
# switch it to evaluation mode for inference.
model = BoundaryDetectionModel().to(device)
_checkpoint = torch.load("checkpoint_epoch_21_eer_0.24.pth", map_location=device)
model.load_state_dict(_checkpoint["model_state_dict"])
del _checkpoint  # the loaded checkpoint dict is not needed after restoring weights
model.eval()
def preprocess_audio(audio_path, sample_rate=16000, target_length=8):
    """Load an audio file and prepare it as model input.

    The file is decoded with ``torchaudio.load``. It is resampled to
    ``sample_rate`` when its native rate differs, and then passed through
    ``pad_audio(waveform, sample_rate, target_length)``.

    Args:
        audio_path: Path of the audio file to read.
        sample_rate: Target sampling rate in Hz.
        target_length: Forwarded unchanged to ``pad_audio``. Presumably a
            duration in seconds — TODO confirm against ``audio_dataset.pad_audio``.

    Returns:
        The padded waveform tensor, moved to the module-level ``device``.
    """
    signal, native_rate = torchaudio.load(audio_path)
    if native_rate != sample_rate:
        resampler = torchaudio.transforms.Resample(native_rate, sample_rate)
        signal = resampler(signal)
    padded = pad_audio(signal, sample_rate, target_length)
    return padded.to(device)
def infer_single_audio(audio_tensor):
    """Run the global ``model`` on one preprocessed clip.

    Args:
        audio_tensor: Model input as produced by ``preprocess_audio``.

    Returns:
        A ``(scores, labels)`` pair of NumPy arrays.
        ``scores`` is the model output with its last axis squeezed.
        ``labels`` is 1 where a score exceeds 0.5 and 0 elsewhere. This
        assumes the model emits per-frame fake probabilities — TODO confirm
        it is not raw logits.
    """
    with torch.no_grad():  # inference only: no autograd graph is recorded
        raw_scores = model(audio_tensor)
        scores = raw_scores.squeeze(-1).cpu().numpy()
    labels = np.greater(scores, 0.5).astype(int)  # binary fake/real label per frame
    return scores, labels
@app.route('/')
def index():
    """Render the ``index.html`` template (upload form and results page)."""
    page = 'index.html'
    return render_template(page)
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded audio clip and render the frame-level results.

    Expects a multipart form upload under the field name ``file``. The clip is
    written to a private temporary file, preprocessed with
    ``preprocess_audio`` and scored with ``infer_single_audio``. The temporary
    file is deleted again before the response is rendered.

    Returns:
        The rendered ``result.html`` page, or a ``(message, 400)`` tuple when
        no file was supplied.
    """
    if 'file' not in request.files:
        return "No file uploaded", 400
    file = request.files['file']
    if file.filename == '':
        return "No selected file", 400

    # Use a unique temporary file per request instead of one shared
    # "temp_audio.wav". Otherwise concurrent requests would overwrite each
    # other's upload. Deleting the file in `finally` also means it is not
    # left on disk when /return is never visited or when inference raises.
    fd, file_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)  # only the path is needed; FileStorage.save reopens it
    try:
        file.save(file_path)
        # Preprocess audio and perform inference (the waveform is fully
        # loaded into memory, so the file can be removed afterwards).
        audio_tensor = preprocess_audio(file_path)
        output, prediction = infer_single_audio(audio_tensor)
    finally:
        try:
            os.remove(file_path)
        except OSError as e:
            print(f"Error deleting temporary file {file_path}: {e}")

    # Flatten the prediction array so 2-D model output becomes one frame sequence
    prediction_flat = prediction.flatten()

    # Total frames, fake frames, and fake percentage (rounded to 4 decimal places)
    total_frames = len(prediction_flat)
    fake_frame_count = int(np.sum(prediction_flat))
    # Guard against an empty prediction instead of raising ZeroDivisionError.
    if total_frames:
        fake_percentage = round((fake_frame_count / total_frames) * 100, 4)
    else:
        fake_percentage = 0.0
    # The clip is reported as fake once at least 5 frames are predicted fake.
    result_type = 'Fake' if fake_frame_count >= 5 else 'Real'

    if result_type == 'Real':
        fake_frame_intervals = "No Frame"  # shown instead of an interval list
    else:
        # Start/end times (seconds) of each consecutive run of fake frames
        fake_frame_intervals = get_fake_frame_intervals(prediction_flat, frame_duration=20)
    # Debug print to check intervals
    print("Fake Frame Intervals:", fake_frame_intervals)

    plot_html = plot_fake_frames_waveform(output, prediction_flat, audio_tensor.cpu().numpy(), fake_frame_intervals)

    return render_template('result.html',
                           fake_percentage=fake_percentage,
                           result_type=result_type,
                           fake_frame_count=fake_frame_count,
                           total_frames=total_frames,
                           fake_frame_intervals=fake_frame_intervals,
                           plot_html=plot_html)
@app.route('/return', methods=['GET'])
def return_to_index():
    """Remove ``temp_audio.wav`` if present, then redirect to the upload page.

    A missing file is the normal case (nothing was uploaded, or it was already
    cleaned up), so it is not reported. Any other ``OSError`` is printed but
    does not prevent the redirect.
    """
    try:
        os.remove("temp_audio.wav")
    except FileNotFoundError:
        pass  # nothing to clean up
    except OSError as e:
        print(f"Error deleting temporary files: {e}")
    return redirect(url_for('index'))  # Redirect back to the main page
def get_fake_frame_intervals(prediction, frame_duration=20):
    """Return the (start, end) times in seconds of every run of fake frames.

    Args:
        prediction: Per-frame labels. An element equal to 1 marks a fake frame;
            anything else is treated as real.
        frame_duration: Length of one frame in milliseconds.

    Returns:
        A list of ``(start_s, end_s)`` tuples, each value rounded to 4 decimals.
        ``end_s`` is the start time of the first non-fake frame after the run.
        For a run that reaches the last frame it is the total duration instead.
    """
    seconds_per_frame = frame_duration / 1000  # ms -> seconds

    def _span(first_frame, stop_frame):
        # Convert a half-open frame range [first_frame, stop_frame) to seconds.
        return (round(first_frame * seconds_per_frame, 4),
                round(stop_frame * seconds_per_frame, 4))

    spans = []
    run_begin = None  # index of the first frame of the fake run in progress
    for index, label in enumerate(prediction):
        fake = label == 1
        if fake and run_begin is None:
            run_begin = index
        elif not fake and run_begin is not None:
            spans.append(_span(run_begin, index))
            run_begin = None

    # A run still open here extends to the end of the clip.
    if run_begin is not None:
        spans.append(_span(run_begin, len(prediction)))
    return spans
def plot_fake_frames_waveform(output, prediction_flat, waveform, fake_frame_intervals, frame_duration=20, sample_rate=16000):
    """Render the waveform as a Plotly chart with fake frames drawn in red.

    Args:
        output: Raw model scores. Not used by the plot; kept for interface
            compatibility.
        prediction_flat: Per-frame labels; an element equal to 1 marks a fake
            frame.
        waveform: Array indexed as ``waveform[0][sample]``. Only the first row
            is plotted, and ``waveform.shape[1]`` is taken as the sample count.
        fake_frame_intervals: Not used by the plot; kept for interface
            compatibility.
        frame_duration: Length of one prediction frame in milliseconds.
        sample_rate: Sampling rate of ``waveform`` in Hz.

    Returns:
        An HTML fragment (``full_html=False``) containing the figure.
    """
    num_samples = waveform.shape[1]
    actual_duration = num_samples / sample_rate
    # Sample i occurs at i / sample_rate seconds. np.linspace(0, duration, n)
    # would instead space samples duration / (n - 1) apart and stretch the axis.
    time = np.arange(num_samples) / sample_rate
    frame_length = int(sample_rate * frame_duration / 1000)  # samples per frame

    # Group consecutive frames with the same label so that each run becomes one
    # trace, instead of one trace per frame. The samples covered and the
    # colours are unchanged.
    is_fake = np.asarray(prediction_flat).reshape(-1) == 1
    change_points = np.flatnonzero(is_fake[1:] != is_fake[:-1]) + 1
    run_starts = np.concatenate(([0], change_points)).astype(int)
    run_ends = np.concatenate((change_points, [len(is_fake)])).astype(int)

    traces = []
    for first_frame, end_frame in zip(run_starts, run_ends):
        start = int(first_frame) * frame_length
        end = min(int(end_frame) * frame_length, num_samples)  # never past the last sample
        if start >= end:
            continue  # run is empty or lies entirely beyond the waveform
        color = 'rgba(255,0,0,0.8)' if is_fake[first_frame] else 'rgba(0,128,0,0.5)'
        traces.append(go.Scatter(
            x=time[start:end],
            y=waveform[0][start:end],
            mode='lines',
            line=dict(color=color),
            showlegend=False
        ))
    # NOTE(review): samples past len(prediction_flat) * frame_length are not
    # drawn (as before); confirm the model's frame count covers the padded clip.

    # Full waveform view to show all fake and real segments
    layout = go.Layout(
        title="Audio Waveform with Fake Frames Highlighted",
        xaxis=dict(title="Time (seconds)", range=[0, actual_duration]),
        yaxis=dict(title="Amplitude"),
        autosize=True,
        template="plotly_white"
    )
    fig = go.Figure(data=traces, layout=layout)
    # Embeddable <div> fragment for the result template
    return fig.to_html(full_html=False)
if __name__ == '__main__':
    # Start Flask's built-in development server with its default settings.
    app.run()