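# Answer image-selection questions by classifying each candidate image with an
# OpenCLIP ViT-B/32 model against per-question class prompts fetched from a remote JSON file.
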
import hashlib
import os
import random
import shutil
import string

import open_clip
import requests
import torch
from PIL import Image
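
# Load the OpenCLIP ViT-B/32 model (LAION-2B s34b_b79k weights), its preprocessing
# transform, and the matching tokenizer once at module import time.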
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
tokenizer = open_clip.get_tokenizer('ViT-B-32')
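

# Helper: build a random, filesystem-safe name for temporary files and directories.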
def generate_random_string_and_hash(length=8):
    # Generate a random string of ASCII letters
    letters = string.ascii_letters
    random_string = ''.join(random.choice(letters) for _ in range(length))
    # Hash it so the resulting name has a uniform length and character set
    hash_value = hashlib.sha256(random_string.encode()).hexdigest()
    return hash_value
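

# Classify one preprocessed image tensor against the tokenized class prompts and
# return the name of the best-matching class.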
def process_img(image_input, text_inputs, classes):
    with torch.no_grad():
        image_features = model.encode_image(image_input)
        text_features = model.encode_text(text_inputs)
    # Pick the single most similar label for the image
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    value, index = similarity[0].topk(1)
    class_name = classes[index]
    return class_name
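

# Download and classify every image URL in `data` for the given question, returning a list of
# booleans marking which images match the expected answer (or None if an example is required
# but not supplied).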
def get_result(question, data, example=None):
    global model, preprocess, tokenizer
    model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
    tokenizer = open_clip.get_tokenizer('ViT-B-32')
    sess = requests.session()
    result = []
    dir_path = generate_random_string_and_hash()
    os.makedirs(f"temp/{dir_path}", exist_ok=True)
    raw_answer = sess.get("https://yundisk.de/d/OneDrive_5G/Pic/data.json").json()
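    # The remote JSON maps each question to its candidate classes, the expected answer,
    # and a flag indicating whether an example image is needed to derive that answer.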
    if question in raw_answer:
        raw_answer = raw_answer[question]
        classes = raw_answer["classes"]
        text_inputs = torch.cat([tokenizer(f"a photo of {c}") for c in classes])
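        # If the question needs an example image, classify it to derive the expected answer;
        # otherwise use the answer stored in the JSON.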
        if raw_answer["need_example"]:
            if example:
                example_file_path = f"{generate_random_string_and_hash()}.png"
                with open(f"temp/{dir_path}/{example_file_path}", "wb+") as f:
                    f.write(sess.get(example).content)
                example = preprocess(Image.open(f"temp/{dir_path}/{example_file_path}")).unsqueeze(0)
                answer = [process_img(example, text_inputs, classes)]
            else:
                print(question)
                return None
        else:
            answer = raw_answer["answer"]
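        # Download each candidate image, classify it, and record whether its class is
        # among the accepted answers.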
        for img in data:
            img_path = f"{generate_random_string_and_hash()}.png"
            with open(f"temp/{dir_path}/{img_path}", "wb+") as f:
                f.write(sess.get(img).content)
            img = preprocess(Image.open(f"temp/{dir_path}/{img_path}")).unsqueeze(0)
            class_name = process_img(img, text_inputs, classes)
            result.append(class_name in answer)
        shutil.rmtree(f"temp/{dir_path}")
        return result