import clip
import gc
import numpy as np
import os
import pandas as pd
import requests
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
#from IPython.display import display
from PIL import Image
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from BLIP.models.blip import blip_decoder
import gradio as gr
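# Run on GPU when available; the BLIP and CLIP models below are moved to this device.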
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
blip_image_eval_size = 384
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
blip_model = blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base')
blip_model.eval()
blip_model = blip_model.to(device)
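# Caption a PIL image with BLIP: resize to 384x384, normalize, then decode with beam search.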
def generate_caption(pil_image):
    gpu_image = transforms.Compose([
        transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    ])(pil_image).unsqueeze(0).to(device)
    with torch.no_grad():
        caption = blip_model.generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5)
    return caption[0]
def load_list(filename):
    with open(filename, 'r', encoding='utf-8', errors='replace') as f:
        items = [line.strip() for line in f.readlines()]
    return items
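# Score each prompt in text_array against the image features with CLIP and
# return the top_count (text, confidence %) pairs.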
def rank(model, image_features, text_array, top_count=1):
    top_count = min(top_count, len(text_array))
    # tokens go to the same device as the CLIP model (CPU here unless CUDA is available)
    text_tokens = clip.tokenize([text for text in text_array]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity = torch.zeros((1, len(text_array))).to(device)
    for i in range(image_features.shape[0]):
        similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
    similarity /= image_features.shape[0]
    top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
    return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy() * 100)) for i in range(top_count)]
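# Gradio handler: caption the image with BLIP, then use CLIP to pick the best
# matching medium, artist, trending site, movement, and flavor terms, and join
# them into a single prompt-style description.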
def interrogate(cover):
    image = Image.fromarray(cover)
    models = models1
    caption = generate_caption(image)
    if len(models) == 0:
        return
    table = []
    # best (text, confidence) ranks per category: medium, artist, trending, movement, flavors
    bests = [[('', 0)]] * 5
    for model_name in models:
        model, preprocess = clip.load(model_name, device=device)
        model.eval()
        images = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            image_features = model.encode_image(images).float()
        image_features /= image_features.norm(dim=-1, keepdim=True)
        ranks = [
            rank(model, image_features, mediums),
            rank(model, image_features, ["by " + artist for artist in artists]),
            rank(model, image_features, trending_list),
            rank(model, image_features, movements),
            rank(model, image_features, flavors, top_count=3)
        ]
        # keep the highest-confidence result per category across models
        for i in range(len(ranks)):
            confidence_sum = 0
            for ci in range(len(ranks[i])):
                confidence_sum += ranks[i][ci][1]
            if confidence_sum > sum(bests[i][t][1] for t in range(len(bests[i]))):
                bests[i] = ranks[i]
        row = [model_name]
        for r in ranks:
            row.append(', '.join([f"{x[0]} ({x[1]:0.1f}%)" for x in r]))
        table.append(row)
        del model
        gc.collect()
    #display(pd.DataFrame(table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors"]))
    flaves = ', '.join([f"{x[0]}" for x in bests[4]])
    medium = bests[0][0][0]
    if caption.startswith(medium):
        return f"{caption} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}"
    else:
        return f"{caption}, {medium} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}"
data_path = "./clip-interrogator/data/"
artists = load_list(os.path.join(data_path, 'artists.txt'))
flavors = load_list(os.path.join(data_path, 'flavors.txt'))
mediums = load_list(os.path.join(data_path, 'mediums.txt'))
movements = load_list(os.path.join(data_path, 'movements.txt'))
sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central']
trending_list = [site for site in sites]
trending_list.extend(["trending on "+site for site in sites])
trending_list.extend(["featured on "+site for site in sites])
trending_list.extend([site+" contest winner" for site in sites])
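# Gradio UI: one CLIP model, an image input, and a text label output.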
models1 = ['ViT-B/32']
width = 130
height = 180
cover = gr.inputs.Image(shape=(width, height), label='Upload cover image to classify')
label = gr.outputs.Label(label='Model prediction')
examples=["00064.jpg","00068.jpg", "00069.jpg"]
title="Image2Text-CLIP Application"
description='''
This text is produced by testing the given image against various artists, mediums, and styles with the OpenAI CLIP model, turning the AI's understanding of the image into words.
<th>
<iframe width="560" height="315" src="https://www.youtube.com/embed/u0HG77RNhPE" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>
</th>
### Upload an image of your choice below, or pick one of the 3 samples
'''
gr.Interface(fn=interrogate, inputs=cover, outputs=label, examples=examples, title=title, description=description).launch()#(share=True)