# ProgramSkripsi / app_utils.py
# Author: Yuuki0 — first commit (e0c75d6)
# app_utils.py
# This file will contain the refactored core logic for training and prediction.
import os
import time
import pickle
import torch
import gradio as gr
from torch import nn
from torch import optim
from torch.optim import lr_scheduler
from model.config import load_config
from model.genconvit_ed import GenConViTED
from model.genconvit_vae import GenConViTVAE
from dataset.loader import load_data, load_checkpoint
from model.pred_func import set_result, load_genconvit, df_face, pred_vid, real_or_fake
# Load configuration
# Module-level state shared by the training and prediction entry points:
# the project config dict and the compute device (CUDA when available).
config = load_config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_available_weights(weight_dir="weight"):
    """Scans the weight directory for .pth files.

    Creates *weight_dir* when it does not exist yet. Returns the list of
    ``.pth`` filenames, or the single-element placeholder list
    ``["No weights found"]`` when there are none.
    """
    if not os.path.exists(weight_dir):
        os.makedirs(weight_dir)
    found = [entry for entry in os.listdir(weight_dir) if entry.endswith(".pth")]
    if not found:
        return ["No weights found"]
    return found
def _count_plain_files(path):
    """Return the number of regular files directly inside *path* (0 if it is not a directory)."""
    # os.path.isdir already returns False for a missing path, so no separate
    # os.path.exists check is needed (the original performed both).
    if not os.path.isdir(path):
        return 0
    return sum(
        1 for name in os.listdir(path)
        if os.path.isfile(os.path.join(path, name))
    )


def count_files_in_subdirs(directory):
    """Counts files in the 'real' and 'fake' subdirectories of a given directory.

    Returns a human-readable summary string of the form
    ``"Real: <n>, Fake: <m>"``; a missing subdirectory counts as 0.
    """
    real_count = _count_plain_files(os.path.join(directory, 'real'))
    fake_count = _count_plain_files(os.path.join(directory, 'fake'))
    return f"Real: {real_count}, Fake: {fake_count}"
def get_dataset_counts():
    """Gets the file counts for train, validation, and test sets.

    Returns a 3-tuple of summary strings, one per split directory
    ('train', 'valid', 'test'), as produced by count_files_in_subdirs.
    """
    return tuple(count_files_in_subdirs(split) for split in ('train', 'valid', 'test'))
def train_model_gradio(model_variant, ed_pretrained, vae_pretrained, epochs, batch_size, run_test, use_fp16, progress=gr.Progress()):
    """Refactored training function for Gradio UI.

    Generator: yields human-readable status strings that Gradio streams to
    the UI, and drives the `progress` bar alongside them.

    Args:
        model_variant: "AE", "VAE", or "AE & VAE" — selects which model(s) to train.
        ed_pretrained: .pth filename in 'weight/' to warm-start the ED model,
            or the placeholder "No weights found" / falsy to skip.
        vae_pretrained: same as above, for the VAE model.
        epochs: number of training epochs (coerced with int()).
        batch_size: dataloader batch size (coerced with int()).
        run_test: if truthy, a post-training test phase is announced
            (currently a no-op — see comment near the end).
        use_fp16: accepted but not used in this function's visible body.
        progress: Gradio progress tracker.
    """
    dir_path = './'
    # Dataset layout is expected at ./train and ./valid, each containing
    # 'real' and 'fake' subdirectories.
    if not (os.path.exists('train') and os.path.exists('valid')):
        yield "Error: 'train' and 'valid' directories not found. Please create them and populate them with 'real' and 'fake' subdirectories."
        return
    yield "Loading data..."
    progress(0, desc="Loading data...")
    try:
        dataloaders, dataset_sizes = load_data(dir_path, int(batch_size))
        yield "Data loaded."
    except Exception as e:
        yield f"Error loading data: {e}. Please ensure the dataset is structured correctly."
        return
    # Collect (tag, model, optimizer) triples for each selected variant.
    models = []
    optimizers = []  # NOTE(review): never populated or read — appears vestigial.
    if model_variant in ["AE", "AE & VAE"]:
        yield "Initializing AE model..."
        model_ed = GenConViTED(config)
        optimizer_ed = optim.Adam(model_ed.parameters(), lr=float(config["learning_rate"]), weight_decay=float(config["weight_decay"]))
        if ed_pretrained and ed_pretrained != "No weights found":
            try:
                model_ed, optimizer_ed, _, _ = load_checkpoint(model_ed, optimizer_ed, filename=os.path.join("weight", ed_pretrained))
            except Exception as e:
                # Checkpoint failure is reported but not fatal: training
                # continues from the freshly initialized weights.
                yield f"Error loading ED checkpoint: {e}"
        models.append(("ed", model_ed, optimizer_ed))
    if model_variant in ["VAE", "AE & VAE"]:
        yield "Initializing VAE model..."
        model_vae = GenConViTVAE(config)
        optimizer_vae = optim.Adam(model_vae.parameters(), lr=float(config["learning_rate"]), weight_decay=float(config["weight_decay"]))
        if vae_pretrained and vae_pretrained != "No weights found":
            try:
                model_vae, optimizer_vae, _, _ = load_checkpoint(model_vae, optimizer_vae, filename=os.path.join("weight", vae_pretrained))
            except Exception as e:
                yield f"Error loading VAE checkpoint: {e}"
        models.append(("vae", model_vae, optimizer_vae))
    # Train each selected model sequentially with its own criterion/scheduler.
    for mod, model, optimizer in models:
        yield f"Starting training for {mod.upper()} model..."
        criterion = nn.CrossEntropyLoss().to(device)
        mse = nn.MSELoss()
        scheduler = lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)
        model.to(device)
        train_loss, train_acc, valid_loss, valid_acc = [], [], [], []
        train_func, valid_func = None, None
        # Deferred imports: pick the train/valid loop implementation that
        # matches the model variant.
        if mod == 'ed':
            from train.train_ed import train as train_func_ed, valid as valid_func_ed
            train_func, valid_func = train_func_ed, valid_func_ed
        else:
            from train.train_vae import train as train_func_vae, valid as valid_func_vae
            train_func, valid_func = train_func_vae, valid_func_vae
        for epoch in range(int(epochs)):
            epoch_desc = f"Epoch {epoch+1}/{int(epochs)} ({mod.upper()})"
            progress(epoch / int(epochs), desc=epoch_desc)
            yield f"{epoch_desc} - Training..."
            epoch_loss, epoch_acc = 0,0  # epoch_acc is assigned but unused below
            try:
                # train_func extends the running train_loss/train_acc lists and
                # returns this epoch's loss (kept for the saved checkpoint).
                train_loss, train_acc, epoch_loss = train_func(model, device, dataloaders["train"], criterion, optimizer, epoch, train_loss, train_acc, mse)
            except Exception as e:
                # NOTE(review): break aborts the epoch loop but execution still
                # falls through to the save step below.
                yield f"Error during training: {e}"
                break
            yield f"{epoch_desc} - Validation..."
            try:
                valid_loss, valid_acc = valid_func(model, device, dataloaders["validation"], criterion, epoch, valid_loss, valid_acc, mse)
                yield f"Epoch {epoch+1} complete for {mod.upper()}. Validation Loss: {valid_loss[-1]:.4f}, Validation Acc: {valid_acc[-1]:.4f}"
            except Exception as e:
                yield f"Error during validation: {e}"
                break
            scheduler.step()
        yield f"Training complete for {mod.upper()}. Saving model..."
        progress(1, desc=f"Saving {mod.upper()} model...")
        # Timestamped base name under weight/; metrics go to .pkl, weights to .pth.
        file_path = os.path.join("weight", f'genconvit_{mod}_{time.strftime("%b_%d_%Y_%H_%M_%S", time.localtime())}')
        with open(f"{file_path}.pkl", "wb") as f:
            pickle.dump([train_loss, train_acc, valid_loss, valid_acc], f)
        state = {
            "epoch": epochs, "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(), "min_loss": epoch_loss,
        }
        weight_filename = f"{file_path}.pth"
        torch.save(state, weight_filename)
        yield f"Model saved to {weight_filename}"
        if run_test:
            yield f"Running test for {mod.upper()} model..."
            # test() function from train.py needs to be refactored to be callable here
            pass
    yield "All training processes finished."
def predict_video_gradio(video_path, ed_weight, vae_weight, num_frames, use_fp16, progress=gr.Progress()):
    """Refactored prediction function for Gradio UI.

    Generator: yields 3-tuples of (status, label, confidence) strings that
    Gradio streams into the UI while driving the `progress` bar.

    BUGFIX: the original mixed `yield` with valued `return` statements. In a
    generator, `return x` does not emit a value to the consumer (it only sets
    StopIteration.value), so every error message and even the final result
    were silently dropped. All exits now `yield` their tuple first. The first
    early exit also returned 4 values where every other path produced 3; the
    arity is now uniformly 3.

    Args:
        video_path: path to the uploaded video; falsy aborts immediately.
        ed_weight / vae_weight: .pth filenames under 'weight/', or the
            placeholder "No weights found" / falsy to skip that model.
        num_frames: number of frames to sample (coerced with int()).
        use_fp16: forwarded to load_genconvit.
        progress: Gradio progress tracker.
    """
    if not video_path:
        yield "Please upload a video.", "", ""
        return
    # Resolve which network to load from the chosen weight files.
    net_type = None
    ed_weight_path, vae_weight_path = None, None
    if ed_weight and ed_weight != "No weights found":
        ed_weight_path = os.path.join("weight", ed_weight)
    if vae_weight and vae_weight != "No weights found":
        vae_weight_path = os.path.join("weight", vae_weight)
    if ed_weight_path and vae_weight_path:
        net_type = 'genconvit'  # combined ED + VAE network
    elif ed_weight_path:
        net_type = 'ed'
    elif vae_weight_path:
        net_type = 'vae'
    else:
        yield "Status: Error", "Please select at least one model weight.", ""
        return
    yield "Status: Loading model...", "", ""
    progress(0.1, desc="Loading model...")
    try:
        model = load_genconvit(config, net_type, ed_weight_path, vae_weight_path, use_fp16)
    except Exception as e:
        yield "Status: Error loading model", f"Details: {e}", ""
        return
    yield "Status: Model loaded. Extracting faces...", "", ""
    progress(0.3, desc="Extracting faces...")
    try:
        faces = df_face(video_path, int(num_frames))
        if len(faces) == 0:
            yield "Status: Error", "No faces detected in the video.", ""
            return
    except Exception as e:
        yield "Status: Error during face extraction", f"Details: {e}. Is dlib installed correctly?", ""
        return
    yield f"Status: {len(faces)} face(s) detected. Running prediction...", "", ""
    progress(0.8, desc="Running prediction...")
    try:
        y, y_val = pred_vid(faces, model)
        label = real_or_fake(y)
        # y_val is the model score for the predicted class; flip it so the
        # reported confidence always refers to the displayed label.
        score = y_val if label == "REAL" else 1 - y_val
        confidence_str = f"{score*100:.2f}%"
        progress(1, desc="Prediction complete")
        yield "Status: Prediction complete.", label, confidence_str
    except Exception as e:
        yield "Status: Error during prediction", f"Details: {e}", ""