import collections
import os
import pickle
from argparse import Namespace
import numpy as np
import torch
from PIL import Image
from torch import cosine_similarity
from transformers import AutoTokenizer, AutoModel
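# Retrieval utilities: dialogue texts (and an optional text -> image mapping) are
# embedded with luotuo-bert, cached as pickles, and looked up by cosine similarity.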
def download_models():
    # Load the embedding model; transformers will download the weights automatically on first use
model_args = Namespace(do_mlm=None, pooler_type="cls", temp=0.05, mlp_only_train=False,
init_embeddings_model=None)
model = AutoModel.from_pretrained("silk-road/luotuo-bert", trust_remote_code=True, model_args=model_args)
return model
class Text:
def __init__(self, text_dir, model, num_steps, text_image_pkl_path=None, dict_text_pkl_path=None, pkl_path=None, dict_path=None, image_path=None, maps_path=None):
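        """
        text_dir:            directory with the raw dialogue text files
        model:               luotuo-bert model returned by download_models()
        num_steps:           maximum number of characters kept per text before embedding
        pkl_path:            pickle of {text: embedding}
        maps_path:           pickle of per-text metadata (title / id / text)
        dict_path:           plain-text file that alternates text lines and image-name lines
        image_path:          directory containing the .jpg images
        text_image_pkl_path: pickle of {text: image name}
        dict_text_pkl_path:  pickle of {dict text: embedding}
        """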
self.dict_text_pkl_path = dict_text_pkl_path
self.text_image_pkl_path = text_image_pkl_path
self.text_dir = text_dir
self.model = model
self.num_steps = num_steps
self.pkl_path = pkl_path
self.dict_path = dict_path
self.image_path = image_path
self.maps_path = maps_path
    def get_embedding(self, texts):
        tokenizer = AutoTokenizer.from_pretrained("silk-road/luotuo-bert")
        model = self.model  # reuse the model passed to the constructor instead of downloading it again
        # Accept either a single string or a list of strings
        texts = texts if isinstance(texts, list) else [texts]
        # Truncate each text to at most num_steps characters
        for i in range(len(texts)):
            if len(texts[i]) > self.num_steps:
                texts[i] = texts[i][:self.num_steps]
        # Tokenize the texts
        inputs = tokenizer(texts, padding=True, truncation=False, return_tensors="pt")
        # Extract the sentence embeddings
        with torch.no_grad():
            embeddings = model(**inputs, output_hidden_states=True, return_dict=True, sent_emb=True).pooler_output
        return embeddings[0] if len(texts) == 1 else embeddings
def read_text(self, save_embeddings=False, save_maps=False):
"""抽取、预存"""
text_embeddings = collections.defaultdict()
text_keys = []
dirs = os.listdir(self.text_dir)
data = []
texts = []
id = 0
for dir in dirs:
            with open(self.text_dir + '/' + dir, 'r', encoding='utf-8') as fr:  # dialogue files are assumed to be UTF-8
for line in fr.readlines():
category = collections.defaultdict(str)
                    ch = ':' if ':' in line else '：'  # use whichever colon (half- or full-width) separates speaker and line
                    if '旁白' in line:  # narration lines: keep everything after the colon
                        text = line.strip().split(ch)[1].strip()
                    else:
                        text = ''.join(list(line.strip().split(ch)[1])[1:-1])  # keep only the text inside the 「」 quotes
                    if text in text_keys:  # skip duplicates so the embeddings and maps stay aligned
continue
text_keys.append(text)
if save_maps:
category["titles"] = dir.split('.')[0]
category["id"] = str(id)
category["text"] = text
id = id + 1
data.append(dict(category))
texts.append(text)
        embeddings = self.get_embedding(texts)
        if save_embeddings:
            for text, embed in zip(texts, embeddings):
                text_embeddings[text] = embed  # reuse the batch embeddings instead of re-encoding every text
if save_embeddings:
self.store(self.pkl_path, text_embeddings)
if save_maps:
self.store(self.maps_path, data)
return text_embeddings, data
def load(self, load_pkl=False, load_maps=False, load_dict_text=False, load_text_image=False):
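        """Return the cached pickle selected by the first flag that is set and whose path was provided."""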
if self.pkl_path and load_pkl:
with open(self.pkl_path, 'rb') as f:
return pickle.load(f)
elif self.maps_path and load_maps:
with open(self.maps_path, 'rb') as f:
return pickle.load(f)
elif self.dict_text_pkl_path and load_dict_text:
with open(self.dict_text_pkl_path, 'rb') as f:
return pickle.load(f)
elif self.text_image_pkl_path and load_text_image:
with open(self.text_image_pkl_path, 'rb') as f:
return pickle.load(f)
else:
print("No pkl_path")
def get_cosine_similarity(self, texts, get_image=False, get_texts=False):
"""
计算文本列表的相似度避免重复计算query_similarity
texts[0] = query
"""
        if get_image:
            pkl = self.load(load_dict_text=True)
        elif get_texts:
            pkl = self.load(load_pkl=True)
        else:
            pkl = {}
            # no cache available: embed the candidate texts on the fly (1536 is luotuo-bert's embedding size)
            embeddings = self.get_embedding(texts[1:]).reshape(-1, 1536)
            for text, embed in zip(texts[1:], embeddings):
                pkl[text] = embed
        query_embedding = self.get_embedding(texts[0]).reshape(1, -1)
        texts_embeddings = np.array([value.numpy().reshape(-1, 1536) for value in pkl.values()]).squeeze(1)
        return cosine_similarity(query_embedding, torch.from_numpy(texts_embeddings))
    def store(self, path, data):
        """Pickle data to path."""
        with open(path, 'wb') as f:
            pickle.dump(data, f)
def text_to_image(self, text, save_dict_text=False):
"""
给定文本出图片
计算query 和 texts 的相似度,取最高的作为new_query 查询image
到text_image_dict 读取图片名
然后到images里面加载该图片然后返回
"""
if save_dict_text:
text_image = {}
            with open(self.dict_path, 'r', encoding='utf-8') as f:
data = f.readlines()
for sub_text, image in zip(data[::2], data[1::2]):
text_image[sub_text.strip()] = image.strip()
self.store(self.text_image_pkl_path, text_image)
keys_embeddings = {}
embeddings = self.get_embedding(list(text_image.keys()))
for key, embed in zip(text_image.keys(), embeddings):
keys_embeddings[key] = embed
self.store(self.dict_text_pkl_path, keys_embeddings)
if self.dict_path and self.image_path:
            # load the cached text -> image-name mapping
text_image = self.load(load_text_image=True)
keys = list(text_image.keys())
keys.insert(0, text)
query_similarity = self.get_cosine_similarity(keys, get_image=True)
            key_index = query_similarity.argmax(dim=0).item()
text = list(text_image.keys())[key_index]
image = text_image[text] + '.jpg'
if image in os.listdir(self.image_path):
res = Image.open(self.image_path + '/' + image)
# res.show()
return res
else:
print("Image doesn't exist")
else:
print("No path")
def text_to_text(self, text):
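        """Return the stored text whose embedding is most similar to the given query text."""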
pkl = self.load(load_pkl=True)
texts = list(pkl.keys())
texts.insert(0, text)
texts_similarity = self.get_cosine_similarity(texts, get_texts=True)
key_index = texts_similarity.argmax(dim=0).item()
value = list(pkl.keys())[key_index]
return value
# if __name__ == '__main__':
# pkl_path = './pkl/texts.pkl'
# maps_path = './pkl/maps.pkl'
# text_image_pkl_path='./pkl/text_image.pkl'
# dict_path = "../characters/haruhi/text_image_dict.txt"
# dict_text_pkl_path = './pkl/dict_text.pkl'
# image_path = "../characters/haruhi/images"
# text_dir = "../characters/haruhi/texts"
# model = download_models()
# text = Text(text_dir, text_image_pkl_path=text_image_pkl_path, maps_path=maps_path,
# dict_text_pkl_path=dict_text_pkl_path, model=model, num_steps=50, pkl_path=pkl_path,
# dict_path=dict_path, image_path=image_path)
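# os.makedirs('./pkl', exist_ok=True)  # assumption: the ./pkl cache directory is not created elsewhere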
# text.read_text(save_maps=True, save_embeddings=True)
# data = text.load(load_pkl=True)
# sub_text = "你好!"
# image = text.text_to_image(sub_text)
# print(image)
# sub_texts = ["hello", "你好"]
# print(text.get_cosine_similarity(sub_texts))
# value = text.text_to_text(sub_text)
# print(value)