# Author: Rohit8y
# Revision: 1deeca6
import datetime
import functools
import os

# Installs OpenAI's CLIP package from GitHub at import time; must run before `import clip`.
os.system('pip install git+https://github.com/openai/CLIP.git')

import clip
import gradio as gr
import PIL
import torch
import torch.nn as nn
from PIL import Image
from torchvision.transforms import transforms
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import Blip2Processor, Blip2ForConditionalGeneration
# Target device for inference: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class OpenaAIClip(nn.Module):
    """CLIP ResNet-50 feature extractor followed by a replaceable head ``fc``.

    Args:
        arch: backbone architecture; only ``"resnet50"`` (CLIP "RN50") is supported.
        modality: ``"image"`` to use image features only, or ``"image+text"`` to
            concatenate image features with caption (text) features.

    ``fc`` starts as ``nn.Identity``; callers may replace it with a classification head.

    Raises:
        ValueError: if ``arch`` is not supported.
    """

    def __init__(self, arch="resnet50", modality="image"):
        super().__init__()
        # Fail fast: an unknown arch used to leave self.model as None and only
        # crash later with an unrelated AttributeError.
        if arch != "resnet50":
            raise ValueError(f"Unsupported arch {arch!r}; only 'resnet50' is available")
        self.modality = modality
        # clip.load also returns a preprocessing transform, which is discarded here.
        self.model, _ = clip.load("RN50")
        if self.modality == "image":
            # Only the visual tower is trainable; forward() never calls the text
            # tower in image-only mode.
            for name, param in self.model.named_parameters():
                param.requires_grad = "visual" in name
        self.fc = nn.Identity()

    def forward(self, image, text=None):
        """Encode ``image`` (and ``text`` in image+text mode) and apply the head.

        Args:
            image: image batch accepted by CLIP's ``encode_image``.
            text: caption string or list of strings; required in "image+text" mode.

        Returns:
            ``self.fc`` applied to the image features, or to the image and text
            features concatenated along dim 1.

        Raises:
            ValueError: if ``text`` is None in "image+text" mode.
        """
        image_features = self.model.encode_image(image)
        if self.modality == "image+text":
            if text is None:
                raise ValueError("text is required when modality is 'image+text'")
            # Tokens must be on the same device as the weights that produced image_features.
            tokens = clip.tokenize(text, truncate=True).to(image_features.device)
            text_features = self.model.encode_text(tokens)
            features = torch.cat((image_features, text_features), dim=1)
        else:
            features = image_features
        # CLIP keeps fp16 weights on CUDA, while an nn.Linear head is fp32 by default.
        # Match the head's dtype so both agree. This is a no-op for a parameter-free
        # head (Identity) or when the dtypes already match.
        head_param = next(self.fc.parameters(), None)
        if head_param is not None:
            features = features.to(head_param.dtype)
        return self.fc(features)
def preprocessing(img, size):
    """Resize, tensorize and normalize a PIL image.

    Args:
        img: PIL image. Images in any mode other than "RGB" (e.g. "L", "RGBA", "P")
            are converted to RGB first, so the result always has 3 channels.
        size: side length of the square output.

    Returns:
        Float tensor of shape (3, size, size), normalized with the standard
        ImageNet mean/std values.
    """
    # Grayscale or alpha images would otherwise yield 1 or 4 channels, which
    # cannot be reshaped into a 3-channel batch.
    if getattr(img, "mode", "RGB") != "RGB":
        img = img.convert("RGB")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    data_transforms = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        normalize,
    ])
    return data_transforms(img)
# UI modality -> (OpenaAIClip modality, head input width, head hidden width).
# Image-only features are 1024 wide; image+text concatenates two of them (2048).
_HEAD_CONFIGS = {
    "Image": ("image", 1024, 512),
    "Image+Text": ("image+text", 2048, 1024),
}


def get_model(model_path, modality):
    """Build a CLIP-RN50 detector with a sigmoid head and load trained weights.

    Args:
        model_path: path to a checkpoint dict whose ``"state_dict"`` entry matches
            the model built here.
        modality: "Image" (image features only) or "Image+Text" (image plus
            caption features).

    Returns:
        The model in eval mode, moved to ``device``. Its head maps features to a
        single sigmoid score per sample.

    Raises:
        ValueError: if ``modality`` is not "Image" or "Image+Text".
    """
    # Validate before the expensive CLIP and checkpoint loads. An unknown modality
    # used to surface as an UnboundLocalError on `model`.
    try:
        clip_modality, dim_mlp, hidden = _HEAD_CONFIGS[modality]
    except KeyError:
        raise ValueError(
            f"Unsupported modality {modality!r}; expected one of {sorted(_HEAD_CONFIGS)}"
        ) from None
    model = OpenaAIClip(arch="resnet50", modality=clip_modality)
    model.fc = nn.Sequential(nn.Linear(dim_mlp, hidden), nn.ReLU(),
                             nn.Linear(hidden, 1), nn.Sigmoid())
    checkpoint_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint_dict['state_dict'])
    # The head built above starts on CPU; move the whole model so the backbone and
    # head share `device`.
    model.to(device)
    model.eval()
    return model
@functools.lru_cache(maxsize=1)
def get_blip_model():
    """Load the BLIP captioning processor and model once and reuse them.

    ``from_pretrained`` deserializes the full checkpoint. Without the cache this
    happened on every caption request. A failed load raises and is not cached.

    Returns:
        Tuple ``(processor, model)`` for ``Salesforce/blip-image-captioning-large``,
        with the model on ``device`` and in eval mode. The same objects are
        returned on every call.
    """
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
    # Captioning inputs are sent to `device`, so the weights must be there too.
    model.to(device)
    model.eval()
    return processor, model
def get_caption(image):
    """Generate a caption for ``image`` with the BLIP captioning model.

    Args:
        image: image accepted by the BLIP processor (the app passes a PIL image).

    Returns:
        The first generated sequence decoded to text, with special tokens removed.
    """
    blip_processor, blip_model = get_blip_model()
    batch = blip_processor(images=image, return_tensors="pt").to(device)
    generated_ids = blip_model.generate(**batch)
    first_sequence = generated_ids[0]
    return blip_processor.decode(first_sequence, skip_special_tokens=True)
def predict(img, caption):
    """Classify an image as "Fake Image" or "Real Image" by majority vote.

    Each checkpoint in ``models_list`` is loaded as an image+text detector and
    scores the image together with a caption. The image is labelled fake when a
    strict majority of the detectors score it >= 0.5.

    Args:
        img: PIL image from the Gradio input, or None when nothing was uploaded.
        caption: optional caption. When it is None, empty or whitespace-only, a
            caption is generated with BLIP instead.

    Returns:
        "Fake Image" or "Real Image", or "Alert: Input image missing" when
        ``img`` is None.

    Side effects:
        Prints progress to stdout and saves ``img`` as ``models/<timestamp>.png``.
    """
    now = datetime.datetime.now()
    print(now)
    # Guard clause: nothing to classify without an image.
    if img is None:
        print("Alert: Input image missing")
        return "Alert: Input image missing"

    print(caption)
    # Whitespace-only text counts as "no caption" rather than being fed to CLIP.
    if caption is None or not caption.strip():
        caption = get_caption(img)
        print("Generated caption-->", caption)
    else:
        print("User input caption-->", caption)
    img.save("models/" + str(now) + ".png")

    models_list = ['models/clip-sd.pth', 'models/clip-glide.pth', 'models/clip-ld.pth']
    modality = "Image+Text"
    # The preprocessed input is the same for every detector, so build it once.
    # Add the batch dimension and place it on `device`, where the CLIP weights live.
    input_tensor = preprocessing(img, 224).view(1, 3, 224, 224).to(device)

    prediction = []
    for model_path in models_list:
        model = get_model(model_path, modality)
        with torch.no_grad():
            out = model(input_tensor, caption)
        print(model_path, ' ----> ', out)
        prediction.append(out.item())

    # Each detector votes "fake" when its sigmoid score is >= 0.5.
    fake_votes = sum(p >= 0.5 for p in prediction)
    return "Fake Image" if fake_votes > len(prediction) / 2 else "Real Image"
# Gradio UI: an image and an optional caption go in, a text label comes out.
image_input = gr.Image(type="pil", label="Input Image")
text_input = gr.Textbox(label="Caption for image (Optional)")
example_rows = [
    ["examples/trump-fake.jpeg", "Donald Trump being arrested by authorities."],
    ["examples/astronaut_space.png", "An astronaut playing basketball with a cat in space, digital art"],
]
iface = gr.Interface(
    fn=predict,
    inputs=[image_input, text_input],
    outputs=gr.Label(),
    examples=example_rows,
)
iface.launch()