import os

os.system('pip install git+https://github.com/openai/CLIP.git')

import gradio as gr
import datetime
import PIL
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch
import clip
import torch.nn as nn
from torchvision.transforms import transforms
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
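
# CLIP-based detector: wraps the OpenAI CLIP RN50 backbone and routes its image
# embedding (optionally concatenated with the text embedding of a caption)
# through a classification head assigned to `self.fc`.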
class OpenaAIClip(nn.Module):
    def __init__(self, arch="resnet50", modality="image"):
        super().__init__()
        self.model = None
        self.modality = modality
        if arch == "resnet50":
            self.model, _ = clip.load("RN50")
        if self.modality == "image":
            # Unfreeze only the visual encoder; keep the text transformer frozen.
            for name, param in self.model.named_parameters():
                if "visual" in name:
                    # print("Unfreezing layer: ", name)
                    param.requires_grad = True
                else:
                    param.requires_grad = False
        self.fc = nn.Identity()

    def forward(self, image, text=None):
        image_features = self.model.encode_image(image)
        if self.modality == "image+text":
            text = clip.tokenize(text, truncate=True).to(device)
            text_features = self.model.encode_text(text)
        else:
            return self.fc(image_features)
        combined_features = torch.cat((image_features, text_features), dim=1)
        return self.fc(combined_features)
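
# preprocessing: resize the PIL image to a (size x size) square, convert it to a
# tensor and apply ImageNet mean/std normalization.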
def preprocessing(img, size):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    data_transforms = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        normalize])
    img = data_transforms(img)
    return img
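
# get_model: build the CLIP classifier for the requested modality, attach a small
# MLP head (1024 -> 512 -> 1 for image only, 2048 -> 1024 -> 1 for image+text,
# ending in a sigmoid), and load the fine-tuned weights from the checkpoint.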
def get_model(model_path, modality):
    if modality == "Image":
        model = OpenaAIClip(arch="resnet50", modality="image")
        dim_mlp = 1024
        fc_units = [512]
        model.fc = nn.Sequential(nn.Linear(dim_mlp, fc_units[0]), nn.ReLU(),
                                 nn.Linear(fc_units[0], 1), nn.Sigmoid())
    elif modality == "Image+Text":
        model = OpenaAIClip(arch="resnet50", modality="image+text")
        dim_mlp = 2048
        fc_units = [1024]
        model.fc = nn.Sequential(nn.Linear(dim_mlp, fc_units[0]), nn.ReLU(),
                                 nn.Linear(fc_units[0], 1), nn.Sigmoid())
    checkpoint_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint_dict['state_dict'])
    # Run in fp32 on the selected device so the CLIP backbone, the head and the
    # inputs all match.
    model = model.float().to(device)
    model.eval()
    return model
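
# get_blip_model: load the BLIP image-captioning model used to generate a caption
# when the user does not provide one (a BLIP-2 variant is left commented out as
# an alternative).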
def get_blip_model():
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
    # processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    # model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
    # Keep the captioning model on the same device as the inputs passed to generate().
    model = model.to(device)
    model.eval()
    return processor, model
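
# get_caption: generate a caption for the input image with BLIP; used as a
# fallback when the caption textbox is left empty.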
def get_caption(image):
    processor, model = get_blip_model()
    inputs = processor(images=image, return_tensors="pt")
    outputs = model.generate(**inputs.to(device))
    caption = processor.decode(outputs[0], skip_special_tokens=True)
    return caption
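
# predict: caption the image if needed, then score it with each of the three
# fine-tuned CLIP detectors and take a majority vote (a score >= 0.5 counts as
# fake); returns "Fake Image" or "Real Image".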
def predict(img, caption):
    now = datetime.datetime.now()
    print(now)
    if img is not None:
        print(caption)
        if caption is None or caption == "":
            caption = get_caption(img)
            print("Generated caption-->", caption)
        else:
            print("User input caption-->", caption)
        img.save("models/" + str(now) + ".png")
        prediction = []
        models_list = ['models/clip-sd.pth', 'models/clip-glide.pth', 'models/clip-ld.pth']
        modality = "Image+Text"
        for i, model_path in enumerate(models_list):
            model = get_model(model_path, modality)
            tensor = preprocessing(img, 224)
            input_tensor = tensor.view(1, 3, 224, 224).to(device)
            with torch.no_grad():
                out = model(input_tensor, caption)
            print(models_list[i], ' ----> ', out)
            prediction.append(out.item())
        # Count the number of predictions that are greater than or equal to 0.5
        count_ones = sum(1 for p in prediction if p >= 0.5)
        if count_ones > len(prediction) / 2:
            return "Fake Image"
        else:
            return "Real Image"
    else:
        print("Alert: Input image missing")
        return "Alert: Input image missing"
# Create Gradio interface
image_input = gr.Image(type="pil", label="Input Image")
text_input = gr.Textbox(label="Caption for image (Optional)")
iface = gr.Interface(fn=predict,
                     inputs=[image_input, text_input],
                     outputs=gr.Label(),
                     examples=[["examples/trump-fake.jpeg", "Donald Trump being arrested by authorities."],
                               ["examples/astronaut_space.png", "An astronaut playing basketball with a cat in space, digital art"]])
iface.launch()