import hashlib
import os
import random
import shutil
import string

import open_clip
import requests
import torch
from PIL import Image

model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
tokenizer = open_clip.get_tokenizer('ViT-B-32')
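# The ViT-B-32 model above (laion2b_s34b_b79k weights) is used for zero-shot image
# classification: images and candidate class prompts ("a photo of <class>") are embedded
# into the same space and compared by cosine similarity, so no task-specific training is needed.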
def generate_random_string_and_hash(length=8):
    """Return the SHA-256 hex digest of a random ASCII-letter string of the given length."""
    # Generate a random string
    letters = string.ascii_letters
    random_string = ''.join(random.choice(letters) for _ in range(length))
    # Hash it to get a filesystem-safe unique name
    hash_value = hashlib.sha256(random_string.encode()).hexdigest()
    return hash_value
def process_img(image_input, text_inputs, classes):
    """Return the class name whose text embedding is most similar to the image embedding."""
    with torch.no_grad():
        image_features = model.encode_image(image_input)
        text_features = model.encode_text(text_inputs)
        # Normalise both embeddings and pick the single most similar label for the image
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        value, index = similarity[0].topk(1)
        class_name = classes[index.item()]
    return class_name
def get_result(question, data, example=None):
    """Classify each image URL in `data` and return a list of booleans indicating
    whether the predicted class is an accepted answer for `question`."""
    global model, preprocess, tokenizer
    # Re-load the model and tokenizer (the module-level instances are replaced on every call)
    model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
    tokenizer = open_clip.get_tokenizer('ViT-B-32')
    sess = requests.Session()
    result = []
    dir_path = generate_random_string_and_hash()
    os.makedirs(f"temp/{dir_path}", exist_ok=True)
    raw_answer = sess.get("https://yundisk.de/d/OneDrive_5G/Pic/data.json").json()
    if question in raw_answer:
        raw_answer = raw_answer[question]
        classes = raw_answer["classes"]
        text_inputs = torch.cat([tokenizer(f"a photo of {c}") for c in classes])
        if raw_answer["need_example"]:
            # The accepted answer is derived from an example image supplied by the caller
            if example:
                example_file_path = f"{generate_random_string_and_hash()}.png"
                with open(f"temp/{dir_path}/{example_file_path}", "wb+") as f:
                    f.write(sess.get(example).content)
                example = preprocess(Image.open(f"temp/{dir_path}/{example_file_path}")).unsqueeze(0)
                answer = [process_img(example, text_inputs, classes)]
            else:
                print(question)
                return None
        else:
            answer = raw_answer["answer"]
        for img in data:
            # Download each image to a temporary file, then classify it
            img_path = f"{generate_random_string_and_hash()}.png"
            with open(f"temp/{dir_path}/{img_path}", "wb+") as f:
                f.write(sess.get(img).content)
            img = preprocess(Image.open(f"temp/{dir_path}/{img_path}")).unsqueeze(0)
            class_name = process_img(img, text_inputs, classes)
            result.append(class_name in answer)
        shutil.rmtree(f"temp/{dir_path}")
        return result
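
# A minimal usage sketch, assuming the data.json served above contains an entry for the
# hypothetical question below with matching "classes"/"answer" fields; the image URLs are
# placeholders, not real endpoints.
if __name__ == "__main__":
    question = "Please click on all images containing a cat"  # hypothetical key in data.json
    image_urls = [
        "https://example.com/captcha/img1.png",  # placeholder URLs
        "https://example.com/captcha/img2.png",
    ]
    matches = get_result(question, image_urls)
    # get_result returns None if the question is unknown or a required example image is
    # missing; otherwise it returns one boolean per input image, in order.
    if matches is not None:
        for url, is_match in zip(image_urls, matches):
            print(f"{url}: {'match' if is_match else 'no match'}")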