import collections
import os
import pickle
from argparse import Namespace

import numpy as np
import torch
from PIL import Image
from torch import cosine_similarity
from transformers import AutoTokenizer, AutoModel


def download_models():
    """Download and return the luotuo-bert sentence-embedding model.

    The transformers hub takes care of downloading/caching the weights
    automatically.
    """
    model_args = Namespace(do_mlm=None, pooler_type="cls", temp=0.05,
                           mlp_only_train=False, init_embeddings_model=None)
    model = AutoModel.from_pretrained("silk-road/luotuo-bert",
                                      trust_remote_code=True,
                                      model_args=model_args)
    return model


class Text:
    """Embedding-based retrieval over a small text corpus.

    Text files under ``text_dir`` are embedded with luotuo-bert and pickled;
    queries are matched against the stored embeddings with cosine similarity,
    optionally mapping the best-matching text to an image file.
    """

    # Tokenizer is created lazily once and shared by all instances -- the
    # original re-created it (and re-downloaded the model) on every
    # get_embedding() call.
    _tokenizer = None

    def __init__(self, text_dir, model, num_steps, text_image_pkl_path=None,
                 dict_text_pkl_path=None, pkl_path=None, dict_path=None,
                 image_path=None, maps_path=None):
        self.dict_text_pkl_path = dict_text_pkl_path  # pickle: text -> embedding (image dict)
        self.text_image_pkl_path = text_image_pkl_path  # pickle: text -> image name
        self.text_dir = text_dir  # directory of source text files
        self.model = model  # pre-loaded luotuo-bert model (see download_models)
        self.num_steps = num_steps  # max characters kept per text before embedding
        self.pkl_path = pkl_path  # pickle: corpus text -> embedding
        self.dict_path = dict_path  # plain file alternating text / image-name lines
        self.image_path = image_path  # directory containing the .jpg images
        self.maps_path = maps_path  # pickle: list of {titles, id, text} records

    def get_embedding(self, texts):
        """Embed a string or list of strings.

        Returns a 1-D tensor for a single input, otherwise a
        ``(len(texts), dim)`` tensor. Inputs longer than ``num_steps``
        characters are truncated first.
        """
        # BUGFIX: reuse the injected model and a cached tokenizer instead of
        # reloading both from the hub on every call (the original ignored
        # self.model entirely).
        if Text._tokenizer is None:
            Text._tokenizer = AutoTokenizer.from_pretrained("silk-road/luotuo-bert")
        tokenizer = Text._tokenizer
        model = self.model if self.model is not None else download_models()

        texts = texts if isinstance(texts, list) else [texts]
        # Character-level truncation, matching the original behavior.
        texts = [t[:self.num_steps] if len(t) > self.num_steps else t for t in texts]

        inputs = tokenizer(texts, padding=True, truncation=False, return_tensors="pt")
        with torch.no_grad():
            embeddings = model(**inputs, output_hidden_states=True,
                               return_dict=True, sent_emb=True).pooler_output
        return embeddings[0] if len(texts) == 1 else embeddings

    def read_text(self, save_embeddings=False, save_maps=False):
        """Extract unique dialogue/narration lines from every file in
        ``text_dir`` and optionally pickle their embeddings and an id map.

        Returns ``(text_embeddings, data)`` where ``text_embeddings`` maps
        each text to its embedding tensor and ``data`` is the list of map
        records (``titles`` / ``id`` / ``text``).
        """
        text_embeddings = collections.defaultdict()
        text_keys = []
        data = []
        texts = []
        idx = 0
        for fname in os.listdir(self.text_dir):
            with open(self.text_dir + '/' + fname, 'r') as fr:
                for line in fr.readlines():
                    category = collections.defaultdict(str)
                    # Prefer the full-width Chinese colon when present.
                    # NOTE(review): the source had two identical branches here,
                    # almost certainly a garbled '：' -- confirm against the
                    # corpus format.
                    ch = '：' if '：' in line else ':'
                    if '旁白' in line:
                        # Narration line: everything after the colon.
                        text = line.strip().split(ch)[1].strip()
                    else:
                        # Dialogue: the text inside the 「」 quotes.
                        text = ''.join(list(line.strip().split(ch)[1])[1:-1])
                    if text in text_keys:
                        # Skip duplicates so embeds and maps keep the same shape.
                        continue
                    text_keys.append(text)
                    if save_maps:
                        category["titles"] = fname.split('.')[0]
                        category["id"] = str(idx)
                        category["text"] = text
                        idx = idx + 1
                        data.append(dict(category))
                    texts.append(text)
        embeddings = self.get_embedding(texts)
        if save_embeddings:
            # BUGFIX: the original recomputed every embedding one-by-one here
            # even though the batch above already contains them all.
            for text, embed in zip(texts, embeddings):
                text_embeddings[text] = embed
            self.store(self.pkl_path, text_embeddings)
        if save_maps:
            self.store(self.maps_path, data)
        return text_embeddings, data

    def load(self, load_pkl=False, load_maps=False, load_dict_text=False,
             load_text_image=False):
        """Unpickle one stored artifact, selected by the first true flag whose
        path is configured; prints a message and returns None otherwise."""
        if self.pkl_path and load_pkl:
            path = self.pkl_path
        elif self.maps_path and load_maps:
            path = self.maps_path
        elif self.dict_text_pkl_path and load_dict_text:
            path = self.dict_text_pkl_path
        elif self.text_image_pkl_path and load_text_image:
            path = self.text_image_pkl_path
        else:
            print("No pkl_path")
            return None
        with open(path, 'rb') as f:
            return pickle.load(f)

    def get_cosine_similarity(self, texts, get_image=False, get_texts=False):
        """Cosine similarity between ``texts[0]`` (the query) and candidates.

        With ``get_image``/``get_texts`` the pre-computed embedding pickle is
        used as the candidate pool (avoiding recomputation); otherwise the
        candidates ``texts[1:]`` are embedded ad hoc. Returns a 1-D tensor of
        similarities aligned with the candidate-pool key order.
        """
        if get_image:
            pkl = self.load(load_dict_text=True)
        elif get_texts:
            pkl = self.load(load_pkl=True)
        else:
            pkl = {}
        # BUGFIX: the original zipped ``texts`` (query included) with the
        # embeddings of ``texts[1:]``, shifting every embedding onto the wrong
        # key and inserting the query into the candidate pool. Only embed
        # candidates that are not already cached.
        missing = [t for t in texts[1:] if t not in pkl]
        if missing:
            embeddings = self.get_embedding(missing).reshape(-1, 1536)
            for text, embed in zip(missing, embeddings):
                pkl[text] = embed
        query_embedding = self.get_embedding(texts[0]).reshape(1, -1)
        # luotuo-bert pooler output is 1536-dimensional.
        texts_embeddings = np.array([value.numpy().reshape(-1, 1536)
                                     for value in pkl.values()]).squeeze(1)
        return cosine_similarity(query_embedding, torch.from_numpy(texts_embeddings))

    def store(self, path, data):
        """Pickle ``data`` to ``path``."""
        with open(path, 'wb+') as f:
            pickle.dump(data, f)

    def text_to_image(self, text, save_dict_text=False):
        """Return the PIL image whose associated text best matches ``text``.

        With ``save_dict_text`` the text->image mapping and its embeddings are
        (re)built first from ``dict_path`` (alternating text / image-name
        lines). Prints a message and returns None when the image or the paths
        are missing.
        """
        if save_dict_text:
            text_image = {}
            with open(self.dict_path, 'r') as f:
                data = f.readlines()
            # The dict file alternates text lines and image-name lines.
            for sub_text, image in zip(data[::2], data[1::2]):
                text_image[sub_text.strip()] = image.strip()
            self.store(self.text_image_pkl_path, text_image)
            keys_embeddings = {}
            embeddings = self.get_embedding(list(text_image.keys()))
            for key, embed in zip(text_image.keys(), embeddings):
                keys_embeddings[key] = embed
            self.store(self.dict_text_pkl_path, keys_embeddings)
        if self.dict_path and self.image_path:
            # Load text -> image-name mapping and rank its keys by similarity.
            text_image = self.load(load_text_image=True)
            keys = list(text_image.keys())
            keys.insert(0, text)
            query_similarity = self.get_cosine_similarity(keys, get_image=True)
            key_index = query_similarity.argmax(dim=0)
            best_text = list(text_image.keys())[key_index]
            image = text_image[best_text] + '.jpg'
            if image in os.listdir(self.image_path):
                res = Image.open(self.image_path + '/' + image)
                # res.show()
                return res
            print("Image doesn't exist")
        else:
            print("No path")

    def text_to_text(self, text):
        """Return the stored corpus text most similar to ``text``."""
        pkl = self.load(load_pkl=True)
        candidates = list(pkl.keys())
        candidates.insert(0, text)
        texts_similarity = self.get_cosine_similarity(candidates, get_texts=True)
        key_index = texts_similarity.argmax(dim=0).item()
        return list(pkl.keys())[key_index]


# Example usage (kept from the original file):
# if __name__ == '__main__':
#     pkl_path = './pkl/texts.pkl'
#     maps_path = './pkl/maps.pkl'
#     text_image_pkl_path = './pkl/text_image.pkl'
#     dict_path = "../characters/haruhi/text_image_dict.txt"
#     dict_text_pkl_path = './pkl/dict_text.pkl'
#     image_path = "../characters/haruhi/images"
#     text_dir = "../characters/haruhi/texts"
#     model = download_models()
#     text = Text(text_dir, text_image_pkl_path=text_image_pkl_path, maps_path=maps_path,
#                 dict_text_pkl_path=dict_text_pkl_path, model=model, num_steps=50, pkl_path=pkl_path,
#                 dict_path=dict_path, image_path=image_path)
#     text.read_text(save_maps=True, save_embeddings=True)
#     data = text.load(load_pkl=True)
#     sub_text = "你好!"
#     image = text.text_to_image(sub_text)
#     print(image)
#     sub_texts = ["hello", "你好"]
#     print(text.get_cosine_similarity(sub_texts))
#     value = text.text_to_text(sub_text)
#     print(value)