|
# Label names keyed by class index; the keys are strings, not ints.
# NOTE(review): presumably the key order matches the model's output logits
# ("0" -> logit 0, ...) — confirm against the training setup.
classes = {
    "0": "nsfw_gore",
    "1": "nsfw_suggestive",
    "2": "safe",
}

# Local filename for the weights file. It is a relative path, so it resolves
# against the current working directory of the process.
model_path = "safesearch_mini_v2.bin"
|
|
|
from transformers import PretrainedConfig, PreTrainedModel |
|
|
|
class SafeSearchConfig(PretrainedConfig):
    """Configuration for the SafeSearch mini v2 classifier.

    Every constructor argument is stored as an attribute of the same name, and
    any extra keyword arguments are forwarded to ``PretrainedConfig``.

    NOTE(review): several field names (``input_size``, ``pool_size``,
    ``crop_pct``, ``interpolation``, ``mean``, ``std``, ``first_conv``,
    ``classifier``, ``label_offset``) look like they mirror timm's
    pretrained-config metadata. Confirm which of them callers actually use.
    """

    model_type = "safesearch_mini_v2"

    def __init__(
        self,
        model_name: str = "safesearch_mini_v2",
        input_channels: int = 3,
        num_classes: int = 3,
        input_size: list = [3, 299, 299],
        pool_size: list = [8, 8],
        crop_pct: float = 0.875,
        interpolation: str = "bicubic",
        mean: list = [0.5, 0.5, 0.5],
        std: list = [0.5, 0.5, 0.5],
        first_conv: str = "conv2d_1a.conv",
        classifier: str = "default",
        has_aux: bool = False,
        label_offset: int = 0,
        classes: object = classes,
        output_channels: int = 1536,
        device: str = "cpu",
        **kwargs
    ):
        """Store the configuration values.

        Args:
            model_name: Name stored on the config.
            input_channels: Number of input channels (default 3).
            num_classes: Number of output classes (default 3).
            input_size: Input shape, default ``[3, 299, 299]`` (presumably
                channels, height, width — TODO confirm).
            pool_size: Stored as-is; default ``[8, 8]``.
            crop_pct: Stored as-is; default ``0.875``.
            interpolation: Stored as-is; default ``"bicubic"``.
            mean: Per-channel values, default ``[0.5, 0.5, 0.5]`` (presumably
                the normalization mean — TODO confirm).
            std: Per-channel values, default ``[0.5, 0.5, 0.5]`` (presumably
                the normalization std — TODO confirm).
            first_conv: Stored as-is; default ``"conv2d_1a.conv"``.
            classifier: Stored as-is; default ``"default"``.
            has_aux: Stored as-is; default ``False``.
            label_offset: Stored as-is; default ``0``.
            classes: Mapping of class index (as a string) to label name;
                defaults to the module-level ``classes`` dict.
            output_channels: Stored as-is; default ``1536``.
            device: Device name, default ``"cpu"``.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.
        """
        import copy

        self.model_name = model_name
        self.input_channels = input_channels
        self.num_classes = num_classes
        # The list/dict defaults in the signature are single objects created
        # once, at definition time; ``classes`` is even the module-level dict.
        # Store private copies so that mutating one config's attribute cannot
        # leak into the defaults, and from there into every later config.
        self.input_size = copy.deepcopy(input_size)
        self.pool_size = copy.deepcopy(pool_size)
        self.crop_pct = crop_pct
        self.interpolation = interpolation
        self.mean = copy.deepcopy(mean)
        self.std = copy.deepcopy(std)
        self.first_conv = first_conv
        self.classifier = classifier
        self.has_aux = has_aux
        self.label_offset = label_offset
        self.classes = copy.deepcopy(classes)
        self.output_channels = output_channels
        self.device = device
        super().__init__(**kwargs)
|
|
|
""" |
|
safesearch_config = SafeSearchConfig() |
|
safesearch_config.save_pretrained("safesearch_config") |
|
""" |
|
|
|
import torch, os, timm |
|
|
|
class SafeSearchModel(PreTrainedModel):
    """Hugging Face wrapper around a timm ``inception_resnet_v2`` classifier.

    On construction, the weights file at the module-level ``model_path`` is
    downloaded from the Hugging Face Hub if it does not exist yet. The file is
    then loaded into a freshly created timm model stored as ``self.model``.
    """

    config_class = SafeSearchConfig

    def __init__(self, config: SafeSearchConfig):
        """Build the timm backbone and load the SafeSearch weights.

        Args:
            config: Model configuration. ``config.num_classes`` sets the size
                of the classification head. ``config.device`` is passed to
                ``torch.load`` as ``map_location``.

        Raises:
            urllib.error.URLError: If the weights must be downloaded and the
                request fails.
            RuntimeError: If the checkpoint does not match the created
                architecture (raised by ``load_state_dict``).
        """
        super().__init__(config)

        if not os.path.exists(model_path):
            self._fetch_safesearch_weights()

        # Size the head from the config rather than a hard-coded 3, so that
        # ``num_classes`` is honoured. The config's default is still 3.
        self.model = timm.create_model(
            "inception_resnet_v2", pretrained=False, num_classes=config.num_classes
        )
        # NOTE(review): torch.load unpickles the file. On torch versions where
        # weights_only does not default to True, a tampered download could
        # execute arbitrary code. Consider passing weights_only=True
        # (torch >= 1.13) once the minimum supported torch version is known.
        # NOTE: map_location only controls where the tensors are materialized
        # while loading. load_state_dict copies them into the model's existing
        # parameters, so this does NOT move self.model to config.device.
        state_dict = torch.load(model_path, map_location=torch.device(config.device))
        self.model.load_state_dict(state_dict)

    @staticmethod
    def _fetch_safesearch_weights():
        """Download the pretrained checkpoint to ``model_path`` atomically.

        The data is written to a temporary ``.part`` file, which is renamed
        over ``model_path`` only after the download completes. An interrupted
        download therefore cannot leave a truncated file that a later run
        would mistake for valid weights (``__init__`` only checks existence).
        """
        from urllib.request import urlretrieve

        url = (
            "https://huggingface.co/FredZhang7/google-safesearch-mini-v2"
            "/resolve/main/pytorch_model.bin"
        )
        tmp_path = model_path + ".part"
        try:
            urlretrieve(url, tmp_path)
            os.replace(tmp_path, model_path)
        finally:
            # The temporary file survives only if the download or rename failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def forward(self, input_ids: torch.Tensor):
        """Run the wrapped timm model on a batch.

        Args:
            input_ids: Passed unchanged to ``self.model``. Despite the name, it
                is presumably an image batch, e.g. float (batch, 3, 299, 299) —
                TODO confirm. The name is kept for interface compatibility.

        Returns:
            The output of ``self.model``. For a timm classifier this is
            presumably the (batch, num_classes) logits.
        """
        return self.model(input_ids)