Spaces:
Runtime error
Runtime error
import itertools | |
import numpy as np | |
from PIL import Image | |
from PIL import ImageSequence | |
from nltk import pos_tag, word_tokenize | |
from LLaMA2_Accessory.SPHINX import SPHINXModel | |
from gpt_combinator import caption_summary | |
class CaptionRefiner():
    """Enrich a short video caption with details queried from a SPHINX model.

    The refiner samples frames from a video, extracts the nouns of the
    original caption, asks the visual model about the scene, the attributes
    of each noun and the spatial relationship between noun pairs, and can
    finally merge everything into one caption through ``caption_summary``.
    """

    def __init__(self, sample_num, add_detect=True, add_pos=True, add_attr=True,
                 openai_api_key=None, openai_api_base=None,
                 ):
        """Store the sampling and feature-toggle configuration.

        Args:
            sample_num: Divisor used to derive the frame sampling stride
                (stride = number of frames / ``sample_num``).
            add_detect: When true, ``caption_refine`` queries a scene
                description from the first sampled frame.
            add_pos: When true, ``caption_refine`` queries the spatial
                relationship of every pair of caption nouns.
            add_attr: When true, ``caption_refine`` queries the attributes of
                every caption noun on every sampled frame.
            openai_api_key: Forwarded to ``caption_summary`` by ``gpt_summary``.
            openai_api_base: Forwarded to ``caption_summary`` by ``gpt_summary``.
        """
        self.sample_num = sample_num
        self.ADD_DETECTION_OBJ = add_detect
        self.ADD_POS = add_pos
        self.ADD_ATTR = add_attr
        self.openai_api_key = openai_api_key
        self.openai_api_base = openai_api_base

    @staticmethod
    def _sample_frames(frames, sample_num):
        """Return every ``stride``-th element of ``frames``, starting at index 0.

        The stride is ``int(len(frames) / sample_num)`` clamped to at least 1,
        so a clip with fewer frames than ``sample_num`` yields all of its
        frames instead of failing on a zero ``range`` step.
        """
        stride = max(1, int(len(frames) / sample_num))
        return frames[::stride]

    @staticmethod
    def _ask(model, question, image):
        """Send one single-turn question about ``image`` to the model."""
        return model.generate_response(
            [[question, None]], image,
            max_gen_len=1024, temperature=0.9, top_p=0.5, seed=0,
        )

    def video_load_split(self, video_path=None):
        """Decode a video file and return a subsample of its frames.

        Only GIF files are decoded; the ``.mp4`` branch is a placeholder and
        yields no frames.

        Args:
            video_path: Path string of the video; the format is chosen by a
                case-insensitive substring match on ``.gif`` / ``.mp4``.

        Returns:
            A list of RGB ``PIL.Image.Image`` frames.

        Raises:
            TypeError: If ``video_path`` is None.
            ValueError: If no frame could be decoded from ``video_path``.
        """
        if video_path is None:
            raise TypeError("video_path must be a path string, got None")

        frame_img_list = []
        lowered_path = video_path.lower()
        if ".gif" in lowered_path:
            # The context manager closes the file handle once every frame has
            # been copied out of the GIF.
            with Image.open(video_path) as gif:
                for frame in ImageSequence.Iterator(gif):
                    # convert() returns a copy, so each frame is detached from
                    # the shared GIF object; the array round-trip yields a
                    # plain RGB image.
                    rgb_pixels = np.array(frame.convert('RGB'), dtype=np.uint8)
                    frame_img_list.append(Image.fromarray(rgb_pixels))
        elif ".mp4" in lowered_path:
            # MP4 decoding is not implemented; no frames are produced.
            pass

        if not frame_img_list:
            raise ValueError(
                "no frames could be decoded from {!r}".format(video_path)
            )

        # sample frames from the mp4/gif
        return self._sample_frames(frame_img_list, self.sample_num)

    def caption_refine(self, video_path, org_caption, model_path):
        """Collect scene, attribute and spatial descriptions for a video.

        Args:
            video_path: Video file handed to ``video_load_split``.
            org_caption: Original caption; its tokens tagged NN/NNS/NNP are
                used as the objects to ask about.
            model_path: Passed as ``pretrained_path`` to
                ``SPHINXModel.from_pretrained``.

        Returns:
            A dict with the keys ``org_caption``, ``scene_description``,
            ``existing_objects``, ``object_attrs`` (a list of
            ``{object: [response per sampled frame]}``) and
            ``space_relations``. Sections whose toggle is off are empty lists.
        """
        sampled_imgs = self.video_load_split(video_path)
        model = SPHINXModel.from_pretrained(
            pretrained_path=model_path,
            with_visual=True
        )

        tagged_tokens = pos_tag(word_tokenize(org_caption))
        existing_objects = [
            word for word, tag in tagged_tokens if tag in ("NN", "NNS", "NNP")
        ]

        scene_description = []
        if self.ADD_DETECTION_OBJ:
            # Detect the scene from the first sampled image. Per-frame object
            # detection was tried and disabled for lacking accuracy.
            scene_description.append(self._ask(
                model,
                "Where is this scene in the picture most likely to take place?",
                sampled_imgs[0],
            ))

        object_attrs = []
        if self.ADD_ATTR:
            # Detailed description of every object on every sampled image.
            for obj in existing_objects:
                question = "Please describe the attribute of the {}, including color, position, etc".format(obj)
                object_attrs.append(
                    {obj: [self._ask(model, question, img) for img in sampled_imgs]}
                )

        space_relations = []
        if self.ADD_POS:
            # Use the last sampled image explicitly. The previous code relied
            # on a loop variable leaked from the attribute pass, which is
            # undefined when that pass is disabled.
            relation_img = sampled_imgs[-1]
            for first_obj, second_obj in itertools.combinations(existing_objects, 2):
                question = "What is the spatial relationship between the {} and the {}? Please describe in lease than twenty words".format(first_obj, second_obj)
                space_relations.append(self._ask(model, question, relation_img))

        return dict(
            org_caption=org_caption,
            scene_description=scene_description,
            existing_objects=existing_objects,
            object_attrs=object_attrs,
            space_relations=space_relations,
        )

    @staticmethod
    def _compose_detailed_caption(total_captions):
        """Concatenate the available caption sections into one long string."""
        parts = []
        if "org_caption" in total_captions:
            parts.append("In summary, " + total_captions['org_caption'])
        # An empty scene list (detection disabled) is skipped rather than
        # indexed, which would raise IndexError.
        scene_description = total_captions.get('scene_description')
        if scene_description:
            parts.append("We first describe the whole scene. " + scene_description[-1])
        if "existing_objects" in total_captions:
            parts.append("There are multiple objects in the video, including ")
            parts.extend(obj + ", " for obj in total_captions['existing_objects'])
        if "space_relations" in total_captions:
            parts.append("As for the spatial relationship. ")
            parts.extend(total_captions['space_relations'])
        return "".join(parts)

    def gpt_summary(self, total_captions):
        """Merge the collected captions and summarize them.

        Args:
            total_captions: Mapping shaped like the result of
                ``caption_refine``; missing keys are skipped.

        Returns:
            Whatever ``caption_summary`` returns for the combined caption and
            the API key/base stored on this instance.
        """
        # combine all captions into a detailed long caption
        detailed_caption = self._compose_detailed_caption(total_captions)
        return caption_summary(detailed_caption, self.openai_api_key, self.openai_api_base)