|
import os
from typing import Dict, List, Any

import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms

from PIL import Image
|
|
|
|
|
|
|
# File name of the saved state_dict that PreTrainedPipeline loads into its
# ResNet-18 classifier.
MODEL_PATH: str = "website_classifier.pth"
|
|
|
|
|
def process_image(image):
    """Turn an image into a single-item batch ready for the classifier.

    Args:
        image: Object exposing ``convert("RGB")``. This is presumably a
            ``PIL.Image.Image``; TODO confirm against callers.

    Returns:
        A tensor of shape ``(1, 3, 224, 224)``. It is the RGB image resized
        to 224x224, converted with ``ToTensor`` and normalized with
        mean 0.5 / std 0.5 on each channel.
    """
    # Force three channels so they line up with the three-entry mean/std below.
    rgb_image = image.convert("RGB")

    preprocess = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    # Prepend a batch axis of size 1: (3, H, W) -> (1, 3, H, W).
    return preprocess(rgb_image).unsqueeze(0)
|
|
|
class PreTrainedPipeline():
    """Website classifier built on a fine-tuned 3-way ResNet-18.

    Loads the weights stored in ``MODEL_PATH`` and, when called, labels an
    input image as one of ``self.classes``.
    """

    def __init__(self, path=""):
        """Build the network and load the fine-tuned weights.

        Args:
            path: Directory containing ``MODEL_PATH``. The default ``""``
                resolves the checkpoint relative to the current working
                directory, which matches the previous behaviour.
        """
        # No pretrained download: the strict load_state_dict below overwrites
        # every parameter and buffer, so ImageNet initialisation is wasted work.
        self.model = torchvision.models.resnet18()

        num_ftrs = self.model.fc.in_features
        # Swap the default classification head for one output per class.
        self.model.fc = nn.Linear(num_ftrs, 3)

        # Same preprocessing as process_image; kept as a public attribute,
        # not used by __call__ (which goes through self.processor).
        self.transform = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

        # map_location="cpu" lets a GPU-saved checkpoint load on CPU-only hosts;
        # the model lives on the CPU either way.
        checkpoint = os.path.join(path, MODEL_PATH)
        self.model.load_state_dict(torch.load(checkpoint, map_location="cpu"))

        # Inference mode: BatchNorm uses its stored running statistics
        # instead of per-batch statistics and stops updating them.
        self.model.eval()

        self.processor = process_image

        # Index order must match the order of the 3 logits produced by fc.
        self.classes = ['forum', 'general', 'marketplace']
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}
        # Misspelled legacy attribute name, kept so existing callers still work.
        self.classe_to_idx = self.class_to_idx

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Classify a single image.

        Args:
            data: Request payload. The image is read from ``data["inputs"]``.
                If that key is missing, the payload itself is passed on as
                the image. A non-dict argument is treated as the image
                directly.

        Returns:
            ``{"class": label}``, where ``label`` is one of ``self.classes``.
        """
        if isinstance(data, dict):
            # .get rather than .pop so the caller's payload is left untouched.
            image = data.get("inputs", data)
        else:
            image = data

        batch = self.processor(image)

        # A ResNet is a plain nn.Module: call it directly (it has no .generate()).
        with torch.no_grad():
            outputs = self.model(batch)

        # Index of the highest logit per row; only the first row is used.
        _, predicted = torch.max(outputs, 1)
        label = self.classes[int(predicted[0])]

        # Return the whole label (previously label[0], i.e. only its first letter).
        return {"class": label}