# class_h1 / model_process.py
# Uploaded by zhou12189108 -- "Upload model_process.py" (commit 78ce12c, verified)
import hashlib
import os
import random
import string
import open_clip
import requests
import torch
import shutil
from PIL import Image
# Build the OpenCLIP 'ViT-B-32' model with the 'laion2b_s34b_b79k' pretrained
# tag once, when this module is imported. The call returns three values; the
# second one is not used in this file and is discarded into `_`.
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
# Tokenizer for the same architecture; used to turn text prompts into model input.
tokenizer = open_clip.get_tokenizer('ViT-B-32')
def generate_random_string_and_hash(length=8):
    """Return the SHA-256 hex digest of a freshly drawn random string.

    A string of ``length`` random ASCII letters is generated and hashed.
    Only the hexadecimal digest is returned; the random string itself is
    discarded.
    """
    alphabet = string.ascii_letters
    # Build the random string one character at a time.
    picked = []
    for _ in range(length):
        picked.append(random.choice(alphabet))
    token = ''.join(picked)
    # Hash the random string and hand back the hex digest.
    digest = hashlib.sha256(token.encode())
    return digest.hexdigest()
def process_img(image_input, text_inputs, classes):
    """Return the entry of ``classes`` that best matches the image.

    ``image_input`` and ``text_inputs`` are encoded with the module-level
    ``model``. Both sets of features are divided by their norm along the
    last dimension, the scaled image/text similarities are passed through a
    softmax, and the entry of ``classes`` at the index of the highest score
    for the first image row is returned.
    """
    with torch.no_grad():
        img_emb = model.encode_image(image_input)
        txt_emb = model.encode_text(text_inputs)
        # Normalise both embeddings in place along the feature dimension.
        img_emb.div_(img_emb.norm(dim=-1, keepdim=True))
        txt_emb.div_(txt_emb.norm(dim=-1, keepdim=True))
        logits = (100.0 * img_emb) @ txt_emb.T
        scores = logits.softmax(dim=-1)
        # Only the single best match for the first image is needed.
        _, best = scores[0].topk(1)
    return classes[best]
def _fetch_image_tensor(sess, url, work_dir):
    """Download ``url`` into ``work_dir`` and return it preprocessed as a batch of one."""
    file_path = f"{work_dir}/{generate_random_string_and_hash()}.png"
    with open(file_path, "wb") as f:
        f.write(sess.get(url).content)
    # Close the image file handle before the directory is removed later on.
    with Image.open(file_path) as picture:
        return preprocess(picture).unsqueeze(0)


def get_result(question, data, example=None):
    """Classify each image in ``data`` and report whether it answers ``question``.

    The per-question configuration is fetched from a remote ``data.json``.
    Each entry provides the candidate ``classes``, a ``need_example`` flag
    and, when no example is needed, the accepted ``answer``. When an example
    is needed, the accepted answer is the class predicted for the ``example``
    image instead.

    Args:
        question: Key to look up in the remote configuration.
        data: Iterable of image URLs to classify.
        example: URL of an example image; only used when the configuration
            entry has ``need_example`` set.

    Returns:
        A list with one boolean per URL in ``data``, true when the predicted
        class is contained in the accepted answer. An empty list when
        ``question`` is not present in the configuration. ``None`` (after
        printing the question) when an example is required but not supplied.
    """
    # The module-level model/preprocess/tokenizer are reused here; rebuilding
    # them on every call only repeated the expensive load done at import.
    work_dir = f"temp/{generate_random_string_and_hash()}"
    os.makedirs(work_dir, exist_ok=True)
    try:
        # The context manager releases the session's connections on exit.
        with requests.session() as sess:
            catalogue = sess.get("https://yundisk.de/d/OneDrive_5G/Pic/data.json").json()
            if question not in catalogue:
                return []

            entry = catalogue[question]
            classes = entry["classes"]
            text_inputs = torch.cat([tokenizer(f"a photo of {c}") for c in classes])

            if entry["need_example"]:
                if not example:
                    print(question)
                    return None
                example_tensor = _fetch_image_tensor(sess, example, work_dir)
                answer = [process_img(example_tensor, text_inputs, classes)]
            else:
                answer = entry["answer"]

            result = []
            for img_url in data:
                img_tensor = _fetch_image_tensor(sess, img_url, work_dir)
                class_name = process_img(img_tensor, text_inputs, classes)
                result.append(class_name in answer)
            return result
    finally:
        # Always remove the scratch directory, including on the early-return
        # and error paths; cleanup is best-effort so it never masks an error.
        shutil.rmtree(work_dir, ignore_errors=True)