import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from io import BytesIO
from scipy.ndimage import gaussian_filter
from model import CLIPViTL14Model
import seaborn as sns
import matplotlib.pyplot as plt
# Normalization statistics: standard ImageNet stats and the stats used by CLIP's image preprocessing.
MEAN = {
    "imagenet": [0.485, 0.456, 0.406],
    "clip": [0.48145466, 0.4578275, 0.40821073],
}
STD = {
    "imagenet": [0.229, 0.224, 0.225],
    "clip": [0.26862954, 0.26130258, 0.27577711],
}
def png2jpg(img, quality):
    """Re-encode a PIL image as JPEG at the given quality to simulate compression."""
    out = BytesIO()
    img.save(out, format='jpeg', quality=quality)  # quality ranges from 0-95; 75 is the PIL default
    img = Image.open(out)
    # Load the pixel data into memory before the BytesIO buffer is closed.
    img = np.array(img)
    out.close()
    return Image.fromarray(img)
def gaussian_blur(img, sigma):
    """Blur each RGB channel in place with a Gaussian filter of the given sigma."""
    img = np.array(img)
    gaussian_filter(img[:, :, 0], output=img[:, :, 0], sigma=sigma)
    gaussian_filter(img[:, :, 1], output=img[:, :, 1], sigma=sigma)
    gaussian_filter(img[:, :, 2], output=img[:, :, 2], sigma=sigma)
    return Image.fromarray(img)
def plot_pie_chart(false_prob, save_path):
    """Save a pie chart of the real/fake probabilities to save_path."""
    labels = ['Real', 'Fake']
    probabilities = [1 - false_prob, false_prob]
    colors = ['#ADD8E6', '#FFC0CB']  # light blue and light pink
    explode = (0.1, 0)  # offset the 'Real' slice slightly for emphasis
    plt.figure(figsize=(6, 6))
    plt.pie(probabilities, labels=labels, colors=colors, explode=explode, autopct='%1.1f%%', startangle=90)
    plt.axis('equal')
    plt.savefig(save_path)
    plt.close()
def detect(
    img_path: str,
    save_path: str,
    pretrained_path: str = None,
    stat_from: str = "clip",
    gaussian_sigma: float = None,
    jpeg_quality: int = None,
    device: str = "cpu",
):
    img = Image.open(img_path).convert("RGB")
    if gaussian_sigma is not None:
        img = gaussian_blur(img, gaussian_sigma)
    if jpeg_quality is not None:
        img = png2jpg(img, jpeg_quality)
    # Resize and normalize with the chosen statistics (CLIP or ImageNet).
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        # transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN[stat_from], std=STD[stat_from]),
    ])
    img = transform(img)
    img: torch.Tensor
    if img.ndim == 3:
        img = img.unsqueeze(dim=0)  # add a batch dimension
    img = img.to(device=device)
    model = CLIPViTL14Model()
    if pretrained_path:
        # Only the linear classification head is loaded from the checkpoint.
        state_dict = torch.load(pretrained_path, map_location=device)
        model.fc.load_state_dict(state_dict)
    model.eval()
    model.to(device=device)
    with torch.no_grad():
        fake_prob = model(img).sigmoid().flatten().tolist()[0]
    plot_pie_chart(fake_prob, save_path)
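
# Minimal usage sketch (assumption: the image path, output path, and checkpoint
# path below are placeholders for illustration, not files shipped with this repo).
if __name__ == "__main__":
    detect(
        img_path="input.png",        # hypothetical image to check
        save_path="prediction.png",  # where the real/fake pie chart is written
        pretrained_path=None,        # optionally, a path to a fine-tuned fc head
        stat_from="clip",            # normalize with CLIP statistics
        gaussian_sigma=None,         # e.g. 1.0 to test robustness to blur
        jpeg_quality=None,           # e.g. 75 to test robustness to JPEG compression
        device="cpu",
    )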