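"""
Flask web app for frame-level audio deepfake detection.

Loads a pre-trained BoundaryDetectionModel checkpoint, accepts an audio upload,
classifies each 20 ms frame as real or fake, and renders the results together
with an interactive Plotly waveform in which fake frames are highlighted.
"""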
from flask import Flask, request, render_template, redirect, url_for
import torch
import torchaudio
import numpy as np
import plotly.graph_objs as go
import os  # Import os for file operations
from model import BoundaryDetectionModel  # Assuming your model is defined here
from audio_dataset import pad_audio  # Assuming you have a function to pad audio

app = Flask(__name__)

# Load the pre-trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BoundaryDetectionModel().to(device)
model.load_state_dict(torch.load("checkpoint_epoch_21_eer_0.24.pth", map_location=device)["model_state_dict"])
model.eval()

def preprocess_audio(audio_path, sample_rate=16000, target_length=8):
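    """
    Load an audio file, resample it to `sample_rate` if necessary, pad it to
    `target_length` seconds with the pad_audio helper, and move it to the model device.
    """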
    waveform, sr = torchaudio.load(audio_path)
    if sr != sample_rate:
        waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)
    waveform = pad_audio(waveform, sample_rate, target_length)
    return waveform.to(device)

def infer_single_audio(audio_tensor):
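    """
    Run the boundary-detection model on a preprocessed waveform and return the raw
    per-frame scores together with binary fake/real predictions (threshold 0.5,
    assuming the model outputs frame-level probabilities).
    """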
    with torch.no_grad():
        output = model(audio_tensor).squeeze(-1).cpu().numpy()
        prediction = (output > 0.5).astype(int)  # Binary prediction for fake/real frames
    return output, prediction

@app.route('/')
def index():
    return render_template('index.html')  # HTML page for file upload and results display

@app.route('/predict', methods=['POST'])
def predict():
    if 'file' not in request.files:
        return "No file uploaded", 400

    file = request.files['file']
    if file.filename == '':
        return "No selected file", 400

    file_path = "temp_audio.wav"  # Temporary file to store uploaded audio
    file.save(file_path)
    
    # Preprocess audio and perform inference
    audio_tensor = preprocess_audio(file_path)
    output, prediction = infer_single_audio(audio_tensor)
    
    # Flatten the prediction array to handle 2D structure
    prediction_flat = prediction.flatten()
    
    # Calculate total frames, fake frame count, and fake percentage (rounded to 4 decimal places)
    total_frames = len(prediction_flat)
    fake_frame_count = int(np.sum(prediction_flat))
    fake_percentage = round((fake_frame_count / total_frames) * 100, 4)
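    # Heuristic clip-level decision: flag the clip as fake when at least 5 frames
    # (roughly 100 ms at the 20 ms frame duration used below) are predicted fake.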
    result_type = 'Fake' if fake_frame_count >= 5 else 'Real'
    
    # Check if audio is classified as real
    if result_type == 'Real':
        fake_frame_intervals = "No Frame"  # Set to "No Frame" if audio is real
    else:
        # Compute start/end times (in seconds) for each run of consecutive fake frames
        fake_frame_intervals = get_fake_frame_intervals(prediction_flat, frame_duration=20)

    # Debug print to check intervals
    print("Fake Frame Intervals:", fake_frame_intervals)
    
    # Generate Plotly plot
    plot_html = plot_fake_frames_waveform(output, prediction_flat, audio_tensor.cpu().numpy(), fake_frame_intervals)

    # Render template with all results and plot
    return render_template('result.html', 
                           fake_percentage=fake_percentage,
                           result_type=result_type,
                           fake_frame_count=fake_frame_count,
                           total_frames=total_frames,
                           fake_frame_intervals=fake_frame_intervals,
                           plot_html=plot_html)

@app.route('/return', methods=['GET'])
def return_to_index():
    # Delete temporary files before returning to index
    try:
        os.remove("temp_audio.wav")  # Remove the temporary audio file
        # If you have any other temporary files (like plots), remove them here too.
        # Example: os.remove("temp_plot.html") if you save plots as HTML files.
    except OSError as e:
        print(f"Error deleting temporary files: {e}")
    
    return redirect(url_for('index'))  # Redirect back to the main page

def get_fake_frame_intervals(prediction, frame_duration=20):
    """
    Calculate start and end times in seconds for each consecutive fake frame interval.
    """
    intervals = []
    start_time = None

    for i, is_fake in enumerate(prediction):
        if is_fake == 1:
            if start_time is None:
                start_time = i * (frame_duration / 1000)  # Convert ms to seconds
        else:
            if start_time is not None:
                end_time = i * (frame_duration / 1000)  # End time of fake segment
                intervals.append((round(start_time, 4), round(end_time, 4)))
                start_time = None

    # Close the final interval if the prediction ends while still inside a fake segment
    if start_time is not None:
        end_time = len(prediction) * (frame_duration / 1000)  # Final end time calculation
        intervals.append((round(start_time, 4), round(end_time, 4)))
    
    return intervals

def plot_fake_frames_waveform(output, prediction_flat, waveform, fake_frame_intervals, frame_duration=20, sample_rate=16000):
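    """
    Build an interactive Plotly figure of the waveform, drawing each frame_duration-ms
    frame in red if it was predicted fake and in green otherwise, and return the figure
    as embeddable HTML. The `output` and `fake_frame_intervals` arguments are accepted
    for convenience but are not currently used when drawing the plot.
    """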
    # Get actual audio duration from waveform for accurate x-axis scaling
    actual_duration = waveform.shape[1] / sample_rate
    num_samples = waveform.shape[1]  # Get number of samples from the actual waveform
    time = np.linspace(0, actual_duration, num_samples)
    
    # Plotly trace for the waveform with different colors for fake and real frames
    frame_length = int(sample_rate * frame_duration / 1000)  # Samples per frame
    
    traces = []
    for i in range(len(prediction_flat)):
        start = i * frame_length
        end = min(start + frame_length, num_samples)  # Clamp so the final frame does not run past the end of the waveform
        color = 'rgba(255,0,0,0.8)' if prediction_flat[i] == 1 else 'rgba(0,128,0,0.5)'
        
        traces.append(go.Scatter(
            x=time[start:end],
            y=waveform[0][start:end],
            mode='lines',
            line=dict(color=color),
            showlegend=False
        ))

    # Full waveform view to show all fake and real segments
    min_time, max_time = 0, actual_duration

    # Layout settings for the plot
    layout = go.Layout(
        title="Audio Waveform with Fake Frames Highlighted",
        xaxis=dict(title="Time (seconds)", range=[min_time, max_time]),
        yaxis=dict(title="Amplitude"),
        autosize=True,
        template="plotly_white"
    )

    fig = go.Figure(data=traces, layout=layout)
    
    # Convert Plotly figure to HTML
    plot_html = fig.to_html(full_html=False)
    return plot_html

if __name__ == '__main__':
    app.run()
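
# A minimal way to exercise the service locally (a sketch, assuming this module is
# saved as app.py, the checkpoint and templates are present, and a local sample.wav
# exists; the Flask dev server listens on http://127.0.0.1:5000 by default):
#
#   python app.py
#   curl -X POST -F "file=@sample.wav" http://127.0.0.1:5000/predict
#
# The response is the rendered result.html page containing the fake percentage,
# frame counts, fake-frame intervals, and the embedded Plotly waveform.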