|
import os |
|
import torch |
|
import torch.nn as nn |
|
import torchvision.models as models |
|
from transformers import PreTrainedModel, AutoConfig |
|
|
|
|
|
class AIDetectorModel(nn.Module):
    """EfficientNetV2-S backbone with a custom three-layer classification head.

    The stock torchvision classifier is replaced by
    ``Linear -> ReLU -> Dropout -> Linear -> ReLU -> Dropout -> Linear``.
    """

    def __init__(self, num_classes=2, dropout=0.3):
        """Build the backbone and attach the replacement classifier head.

        Args:
            num_classes: Width of the final linear layer (number of logits).
                Defaults to 2, the value that was previously hard-coded.
            dropout: Dropout probability used after each hidden layer.
                Defaults to 0.3, the value that was previously hard-coded.
        """
        super().__init__()

        # weights=None: no pretrained torchvision weights are requested here;
        # parameters are expected to come from a checkpoint loaded later.
        self.base_model = models.efficientnet_v2_s(weights=None)

        # Read the input width of the stock head before replacing it, so the
        # new head plugs into the backbone's feature vector unchanged.
        in_features = self.base_model.classifier[1].in_features

        # Layer order (and therefore state-dict key numbering) must stay as-is
        # for existing checkpoints to remain loadable.
        self.base_model.classifier = nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Run ``x`` through the backbone and head and return the raw logits."""
        return self.base_model(x)
|
|
|
|
|
class AIDetectorForImageClassification(PreTrainedModel):
    """``PreTrainedModel`` wrapper around ``AIDetectorModel``.

    On construction the wrapper builds the detector network and loads its
    weights from ``best_model_improved.pth`` in the current working directory.
    """

    def __init__(self, config):
        """Build the detector and load its checkpoint from the working directory.

        Args:
            config: Model configuration. ``config.num_labels`` is stored as
                ``self.num_labels`` and used to reshape logits for the loss.

        Raises:
            Any exception raised by ``torch.load`` (for example when the
            checkpoint file is missing) propagates unchanged, as does a
            ``RuntimeError`` from the non-strict fallback load.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = AIDetectorModel()

        model_path = os.path.join(os.getcwd(), "best_model_improved.pth")

        # Read the checkpoint once and reuse it for the non-strict fallback
        # instead of deserialising the file a second time.
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source, or pass weights_only=True if the installed
        # torch version supports it and the file holds a plain state dict.
        state_dict = torch.load(model_path, map_location=torch.device("cpu"))

        try:
            self.model.load_state_dict(state_dict)
        except RuntimeError as e:
            # load_state_dict reports key/shape mismatches as RuntimeError;
            # only that case is worth retrying with strict=False.
            print(f"Error with strict loading: {e}")
            print("Trying with strict=False...")
            self.model.load_state_dict(state_dict, strict=False)
            print("Model loaded with strict=False")
        else:
            print(f"Model loaded successfully from {model_path}")

    def forward(self, pixel_values, labels=None, **kwargs):
        """Compute logits and, when labels are given, the cross-entropy loss.

        Args:
            pixel_values: Input passed straight to the wrapped detector
                (presumably a batch of image tensors; confirm against caller).
            labels: Optional target class indices. When provided, a
                cross-entropy loss is computed over the flattened logits.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            ``{"logits": logits}`` when ``labels`` is ``None``; otherwise
            ``{"loss": loss, "logits": logits}``.
        """
        logits = self.model(pixel_values)

        if labels is None:
            return {"logits": logits}

        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return {"loss": loss, "logits": logits}
|
|
|
|
|
def get_model(config_path="./"):
    """Create an ``AIDetectorForImageClassification`` from a saved config.

    Args:
        config_path: Location handed to ``AutoConfig.from_pretrained``.
            Defaults to ``"./"`` (the current directory), which was
            previously hard-coded.

    Returns:
        A newly constructed ``AIDetectorForImageClassification``; its
        constructor also loads the detector checkpoint.
    """
    config = AutoConfig.from_pretrained(config_path)
    return AIDetectorForImageClassification(config)
|
|