import io
import json
import logging
import os
import tempfile
from typing import List, Union

import nest_asyncio
import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F
import torchaudio
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from torch import nn

from calculate_modules import compute_eer
from model_utils import Model

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Serve the front-end assets and landing page from the Web/ directory.
app.mount("/static", StaticFiles(directory="Web"), name="static")


@app.get("/")
def home():
    return FileResponse("Web/index.html")

# Allow cross-origin requests (wide-open settings, convenient for development).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def load_config(config_path):
    """Load the JSON experiment configuration from disk."""
    try:
        with open(config_path, "r") as f:
            return json.load(f)
    except Exception as e:
        logger.error(f"Error loading configuration: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Error loading configuration: {e}",
        )
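
# Note: the only configuration key this service relies on is
# config["model_config"], which is passed to Model as d_args below; any other
# fields in the .conf file are ignored here.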


def load_model(checkpoint_path, d_args):
    """Instantiate one Model per checkpoint and return them in eval mode."""
    models_list = []
    for i, path in enumerate(checkpoint_path):
        model = Model(d_args)
        try:
            checkpoint = torch.load(path, map_location=torch.device("cpu"))
            model.load_state_dict(checkpoint)
            logger.info(f"Model_{i} loaded successfully.")
        except Exception as e:
            logger.error(f"Error loading model_{i}: {e}")
            raise
        model.eval()
        models_list.append(model)
    return models_list
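
# Example: load_model(["./S1_best.pth", "./S2_best.pth"], d_args) returns a
# two-model ensemble whose softmax outputs infer() averages.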


def preprocess_audio(audio_path, sample_rate=16000):
    """Load an audio file, resample it to `sample_rate`, and downmix to mono."""
    try:
        logger.info(f"Loading audio: {audio_path}")
        waveform, sr = torchaudio.load(audio_path)
        logger.info(f"Audio loaded: {audio_path}, sample rate: {sr}")
        if sr != sample_rate:
            resample_transform = torchaudio.transforms.Resample(
                orig_freq=sr, new_freq=sample_rate
            )
            waveform = resample_transform(waveform)
        if waveform.size(0) > 1:
            # Average the channels to obtain a mono signal.
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        return waveform
    except Exception as e:
        logger.error(f"Error in audio preprocessing: {e}")
        raise HTTPException(
            status_code=500, detail=f"Error in audio preprocessing: {e}"
        )
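
# The 16 kHz default is assumed to match the checkpoints' training data
# (ASVspoof corpora are distributed at 16 kHz); change `sample_rate` only if
# the models were trained at a different rate.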


def infer(model_list, waveform, freq_aug=False):
    """Average the softmax outputs of the ensemble and return (label, score)."""
    try:
        with torch.no_grad():
            probabilities_sum = None
            for model in model_list:
                last_hidden, output = model(waveform, Freq_aug=freq_aug)
                logger.debug(f"Model output: {output}")
                if output is None:
                    raise ValueError("Model output is None.")
                probabilities = F.softmax(output, dim=1)
                if probabilities_sum is None:
                    probabilities_sum = probabilities
                else:
                    # Keep the accumulator a torch tensor; mixing in numpy
                    # arrays here would break torch.argmax below.
                    probabilities_sum += probabilities
            probabilities_sum = probabilities_sum / len(model_list)
            predicted_label = torch.argmax(probabilities_sum, dim=1).item()
            confidence = probabilities_sum[0].tolist()
            # With two classes, 1 - max(p) is the probability mass of the
            # non-predicted class.
            max_confidence = 1 - max(confidence)
            return predicted_label, max_confidence
    except Exception as e:
        logger.error(f"Error during inference: {e}")
        raise
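
# Worked example (illustrative numbers): if the averaged probabilities are
# [0.9, 0.1], infer() returns label 0 ("Genuine") with score 1 - 0.9 = 0.1,
# so lower scores indicate a more confident prediction.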


# Load the experiment configuration and the four-model ensemble at startup.
config_path = "./AASIST_ASVspoof5_Exp4_CL.conf"
config = load_config(config_path)
d_args = config["model_config"]
checkpoint_path = [f"./S{i+1}_best.pth" for i in range(4)]
model_list = load_model(checkpoint_path, d_args)


@app.post("/predict/")
async def predict(files: List[UploadFile] = File(...)):
    """
    Endpoint to handle batch inference for multiple audio files.
    Accepts a list of audio files and returns inference results for each file.
    """
    responses = []
    bonafide_scores = []
    spoof_scores = []

    for file in files:
        temp_audio_path = None
        try:
            logger.info(f"Processing file: {file.filename}")

            if not file.filename.endswith((".wav", ".flac")):
                raise HTTPException(
                    status_code=400,
                    detail="Invalid file format. Only .wav and .flac files are allowed.",
                )

            # Persist the upload to a temporary file so torchaudio can read it.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
                temp_audio_path = temp_audio.name
                temp_audio.write(await file.read())

            waveform = preprocess_audio(temp_audio_path)

            # Volume heuristic: flag near-silent clips and boost quiet ones
            # (rms: root-mean-square amplitude; max_eng: peak absolute amplitude).
            rms = torch.sqrt(torch.mean(waveform**2)).item()
            max_eng = torch.sqrt(torch.max(waveform**2)).item()
            volume = "normal"
            if rms < 0.01:
                volume = "silent"
            elif rms < 0.05 and max_eng < 0.3:
                volume = "augmented"
                waveform = waveform * (1 / (3 * rms))

            logger.info(
                f"Volume state for {file.filename}: rms={rms}, max_eng={max_eng} -> {volume}"
            )

            label, confidence = infer(model_list, waveform)

            if label == 0:
                bonafide_scores.append(confidence)
            else:
                spoof_scores.append(confidence)

            responses.append(
                {
                    "filename": file.filename,
                    "label": "Genuine" if label == 0 else "Spoof",
                    "confidence": confidence,
                    "status": "success",
                    "volume": volume,
                }
            )
        except Exception as e:
            responses.append(
                {"filename": file.filename, "error": str(e), "status": "failed"}
            )
        finally:
            # Always remove the temporary file, even when processing failed.
            if temp_audio_path is not None and os.path.exists(temp_audio_path):
                os.unlink(temp_audio_path)

    logger.info(f"Bonafide scores: {bonafide_scores}")
    logger.info(f"Spoof scores: {spoof_scores}")

    # The EER is only defined when the batch contains both classes.
    if bonafide_scores and spoof_scores:
        eer, _, _, _ = compute_eer(np.array(bonafide_scores), np.array(spoof_scores))
        eer_percentage = eer * 100
        responses.append({"EER": f"{eer_percentage:.2f}%"})
        logger.info(f"Calculated EER: {eer_percentage:.2f}%")
    else:
        logger.info("Not enough data to calculate EER.")

    return responses
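
# Example client call (illustrative; assumes the server is reachable on
# localhost:7860 and that sample1.wav / sample2.flac exist):
#
#   curl -X POST http://localhost:7860/predict/ \
#        -F "files=@sample1.wav" -F "files=@sample2.flac"
#
# On success, each entry of the JSON response looks like:
#   {"filename": "sample1.wav", "label": "Genuine", "confidence": 0.12,
#    "status": "success", "volume": "normal"}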


# nest_asyncio allows uvicorn.run() to start inside an already-running event
# loop (e.g. when this script is launched from a notebook).
nest_asyncio.apply()

uvicorn.run(
    app,
    host="0.0.0.0",
    port=7860,
    proxy_headers=True,
    forwarded_allow_ips="*",
)