# main.py — FastAPI service exposing deepfake-audio detection and
# scam-text classification endpoints.
import os
import tempfile
from pathlib import Path

import torch
import torch.nn.functional as F
import torchaudio
from fastapi import FastAPI, File, UploadFile
from pydantic import BaseModel
from transformers import AutoProcessor, AutoModelForAudioClassification, pipeline
# Resolve paths relative to this file so the service works regardless of CWD.
app_dir = Path(__file__).parent
# Deepfake model setup: load the audio-classification processor and model
# from the local "Deepfake/model" directory bundled next to this file.
# local_files_only=True forbids any network download at startup.
deepfake_model_path = app_dir / "Deepfake" / "model"
deepfake_processor = AutoProcessor.from_pretrained(deepfake_model_path)
deepfake_model = AutoModelForAudioClassification.from_pretrained(
    pretrained_model_name_or_path=deepfake_model_path,
    local_files_only=True,
)
def prepare_audio(file_path, sampling_rate=16000, duration=10):
    """Load an audio file and split it into fixed-length mono chunks.

    The audio is downmixed to mono (channel mean), resampled to
    ``sampling_rate`` Hz if needed, and cut into ``duration``-second
    chunks; the final chunk is zero-padded to the full chunk length.

    Args:
        file_path: Path to an audio file readable by torchaudio.
        sampling_rate: Target sample rate in Hz (default 16 kHz, the
            rate the deepfake model was trained on).
        duration: Chunk length in seconds.

    Returns:
        List of 1-D numpy arrays, each of exactly
        ``sampling_rate * duration`` samples. Empty for zero-length audio.
    """
    waveform, original_sampling_rate = torchaudio.load(file_path)
    # Downmix multi-channel audio to mono by averaging channels.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    if original_sampling_rate != sampling_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=original_sampling_rate, new_freq=sampling_rate
        )
        waveform = resampler(waveform)
    chunk_size = sampling_rate * duration
    audio_chunks = []
    for start in range(0, waveform.shape[1], chunk_size):
        chunk = waveform[:, start:start + chunk_size]
        if chunk.shape[1] < chunk_size:
            # Zero-pad the trailing partial chunk on the right so every
            # chunk the model sees has an identical length.
            padding = chunk_size - chunk.shape[1]
            chunk = F.pad(chunk, (0, padding))  # use the F alias imported above
        audio_chunks.append(chunk.squeeze().numpy())
    return audio_chunks
def predict_audio(file_path):
    """Classify an audio file as deepfake/real via chunk-wise majority vote.

    Each 10-second chunk is classified independently; the final label is
    the most common per-chunk prediction and the confidence is the mean
    of the per-chunk max softmax probabilities.

    Args:
        file_path: Path to an audio file readable by torchaudio.

    Returns:
        Dict with "predicted_label" (model's id2label string, or None if
        the file contained no audio) and "average_confidence" (float).
    """
    audio_chunks = prepare_audio(file_path)
    # Guard: a zero-length file yields no chunks; without this check the
    # confidence average below would raise ZeroDivisionError.
    if not audio_chunks:
        return {"predicted_label": None, "average_confidence": 0.0}
    predictions = []
    confidences = []
    for chunk in audio_chunks:
        inputs = deepfake_processor(
            chunk, sampling_rate=16000, return_tensors="pt", padding=True
        )
        with torch.no_grad():  # inference only — no gradient bookkeeping
            outputs = deepfake_model(**inputs)
            logits = outputs.logits
            probabilities = F.softmax(logits, dim=1)
            confidence, predicted_class = torch.max(probabilities, dim=1)
            predictions.append(predicted_class.item())
            confidences.append(confidence.item())
    # Majority vote over chunk-level class ids; ties resolved by set order.
    aggregated_prediction_id = max(set(predictions), key=predictions.count)
    predicted_label = deepfake_model.config.id2label[aggregated_prediction_id]
    average_confidence = sum(confidences) / len(confidences)
    return {
        "predicted_label": predicted_label,
        "average_confidence": average_confidence
    }
# ScamText model setup: builds a text-classification pipeline from the
# "phishbot/ScamLLM" Hugging Face model (fetched from the hub at import
# time unless already cached locally).
scamtext_pipe = pipeline("text-classification", model="phishbot/ScamLLM")
# Input model for scam text inference
class TextInput(BaseModel):
    """Request body for POST /scamtext/infer.

    The field name "input" is part of the public request schema; callers
    send ``{"input": "<text to classify>"}``.
    """
    input: str
# Initialize FastAPI application; route handlers below register on it.
app = FastAPI()
@app.post("/deepfake/infer")
async def deepfake_infer(file: UploadFile = File(...)):
    """Run deepfake detection on an uploaded audio file.

    The upload is spooled to a secure temporary file rather than a path
    derived from the client-supplied filename (which permitted path
    traversal via names like "../x.wav", collided across concurrent
    requests, and crashed when the filename was missing), classified
    with predict_audio, and the temp file is always removed.

    Returns:
        The dict produced by predict_audio.
    """
    # Preserve the original extension so torchaudio picks the right decoder.
    suffix = Path(file.filename).suffix if file.filename else ""
    fd, temp_file_path = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as temp_file:
            temp_file.write(await file.read())
        predictions = predict_audio(temp_file_path)
    finally:
        os.remove(temp_file_path)
    return predictions
@app.post("/scamtext/infer")
async def scamtext_infer(data: TextInput):
    """Classify the submitted text with the ScamLLM pipeline.

    Returns the pipeline output (a list of label/score records) directly
    as the JSON response body.
    """
    result = scamtext_pipe(data.input)
    return result
@app.get("/deepfake/health")
async def deepfake_health():
    """Health probe; also reports the torchaudio backends available."""
    backends = torchaudio.list_audio_backends()
    return {"message": "ok", "Sound": str(backends)}
@app.get("/scamtext/health")
async def scamtext_health():
    """Liveness probe for the scam-text endpoint."""
    return {"message": "ok"}
# Script entry point: serve the app on all interfaces, port 8000.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)