Spaces:
Runtime error
Runtime error
File size: 5,319 Bytes
8a09a62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
import argparse
import torch
import sys
import os
# 添加当前命令行运行的目录到 sys.path
sys.path.append(os.getcwd()+"/dialoggen")
from llava.constants import (
IMAGE_TOKEN_INDEX,
DEFAULT_IMAGE_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IM_END_TOKEN,
IMAGE_PLACEHOLDER,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
process_images,
tokenizer_image_token,
get_model_name_from_path,
)
import requests
from PIL import Image
from io import BytesIO
import re
def image_parser(image_file, sep=','):
    """Split a delimiter-separated string of image locations into a list.

    Args:
        image_file: One or more image paths/URLs joined by ``sep``.
        sep: Delimiter between entries (default ``','``).

    Returns:
        List of the raw entries, in order; surrounding whitespace is kept.
    """
    return image_file.split(sep)
def load_image(image_file, timeout=30):
    """Load an image from an HTTP(S) URL or a local path, converted to RGB.

    Args:
        image_file: A URL (any string starting with ``"http"``, which covers
            ``"https"`` too) or a filesystem path.
        timeout: Seconds to wait on the remote server before giving up.
            Only used for URLs. Defaults to 30.

    Returns:
        A ``PIL.Image.Image`` in RGB mode.

    Raises:
        requests.HTTPError: If the URL responds with an error status.
        OSError: If the file cannot be opened or decoded as an image.
    """
    # "https" already starts with "http", so one prefix check covers both schemes.
    if image_file.startswith("http"):
        # Without a timeout, requests can block forever on an unresponsive host.
        response = requests.get(image_file, timeout=timeout)
        # Surface HTTP failures directly instead of letting PIL fail to decode
        # an error page's body.
        response.raise_for_status()
        with Image.open(BytesIO(response.content)) as img:
            return img.convert("RGB")
    # The context manager closes the underlying file handle deterministically;
    # convert() returns a new, fully loaded image that outlives it.
    with Image.open(image_file) as img:
        return img.convert("RGB")
def load_images(image_files):
    """Load each image location in ``image_files`` with ``load_image``.

    Returns:
        A list of RGB images in the same order as the input.
    """
    return [load_image(path) for path in image_files]
def init_dialoggen_model(model_path, model_base=None):
    """Load the DialogGen checkpoint and bundle the pieces ``eval_model`` needs.

    Args:
        model_path: Checkpoint path. The model name handed to the loader is
            derived from it with ``get_model_name_from_path``.
        model_base: Optional base-model path, forwarded to
            ``load_pretrained_model``.

    Returns:
        dict with keys ``"tokenizer"``, ``"model"`` and ``"image_processor"``.
        The context length reported by the loader is discarded.
    """
    loaded = load_pretrained_model(
        model_path,
        model_base,
        get_model_name_from_path(model_path),
        llava_type_model=True,
    )
    tokenizer, model, image_processor, _context_len = loaded
    return dict(tokenizer=tokenizer, model=model, image_processor=image_processor)
def _insert_image_token(qs, mm_use_im_start_end):
    """Return ``qs`` with the image token inserted where the model expects it.

    Every ``IMAGE_PLACEHOLDER`` occurrence is replaced; if there is none, the
    token is prepended on its own line. With ``mm_use_im_start_end`` the token
    is wrapped in the image start/end markers.
    """
    if mm_use_im_start_end:
        image_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_token = DEFAULT_IMAGE_TOKEN
    if IMAGE_PLACEHOLDER in qs:
        # Literal replacement, matching the literal ``in`` test above; re.sub
        # would interpret the placeholder as a regex pattern.
        return qs.replace(IMAGE_PLACEHOLDER, image_token)
    return image_token + "\n" + qs


def _prepare_images(models, image_file, sep):
    """Build ``(images_tensor, image_sizes)`` for ``generate``.

    With ``image_file`` set, loads the ``sep``-separated images and runs them
    through the model's image processor. Without one, builds an all-zero
    (1, 5, 3, H, W) tensor, with H/W taken from the processor's ``crop_size``,
    and a nominal (1024, 1024) size. The original code notes this mirrors the
    training-data input format. In both cases the tensor is cast to float16
    on the model's device.
    """
    model = models["model"]
    if image_file is not None:
        images = load_images(image_parser(image_file, sep=sep))
        image_sizes = [img.size for img in images]
        images_tensor = process_images(images, models["image_processor"], model.config)
    else:
        crop = models["image_processor"].crop_size
        image_sizes = [(1024, 1024)]
        images_tensor = torch.zeros(1, 5, 3, crop["height"], crop["width"])
    return images_tensor.to(model.device, dtype=torch.float16), image_sizes


def eval_model(models,
               query='详细描述一下这张图片',
               image_file=None,
               sep=',',
               temperature=0.2,
               top_p=None,
               num_beams=1,
               max_new_tokens=512,
               ):
    """Run one generation turn of the DialogGen model and return its reply.

    Args:
        models: dict from ``init_dialoggen_model`` (keys ``"tokenizer"``,
            ``"model"``, ``"image_processor"``).
        query: User text. The default asks, in Chinese, for a detailed
            description of the image.
        image_file: Optional ``sep``-separated image paths/URLs. When None, a
            zero-filled dummy image tensor is used instead.
        sep: Separator for ``image_file``.
        temperature: Sampling temperature. Sampling is enabled only when > 0.
        top_p: Nucleus-sampling threshold, forwarded to ``generate``.
        num_beams: Beam count, forwarded to ``generate``.
        max_new_tokens: Generation length limit, forwarded to ``generate``.

    Returns:
        The first decoded output sequence, with special tokens skipped and
        surrounding whitespace stripped.
    """
    disable_torch_init()
    model = models["model"]
    tokenizer = models["tokenizer"]

    qs = _insert_image_token(query, model.config.mm_use_im_start_end)
    conv = conv_templates['llava_v1'].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    images_tensor, image_sizes = _prepare_images(models, image_file, sep)

    # Use the model's device rather than a hard-coded .cuda(), so the token ids
    # always sit next to the image tensor and the weights (and CPU runs work).
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=images_tensor,
            image_sizes=image_sizes,
            do_sample=temperature > 0,
            temperature=temperature,
            top_p=top_p,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            use_cache=True,
        )
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
def remove_prefix(text):
    """Interpret the intent marker at the start of the model's reply.

    Returns:
        ``(prompt, compliant)``:

        - reply starts with the draw tag ``<画图>`` ("draw"): the text after
          the tag, True
        - reply starts with ``对不起`` ("sorry"), i.e. a refusal to draw:
          ``""``, False
        - anything else: ``text`` unchanged, True
    """
    draw_tag = "<画图>"
    if text.startswith(draw_tag):
        return text[len(draw_tag):], True
    if text.startswith("对不起"):
        return "", False
    return text, True
class DialogGen:
    """Prompt enhancer backed by the DialogGen model.

    The query template instructs the model to prefix its answer with a draw
    tag when the user wants an image. ``remove_prefix`` then strips that tag
    or detects a refusal.
    """

    def __init__(self, model_path):
        # Load tokenizer / model / image processor once and reuse them per call.
        self.models = init_dialoggen_model(model_path)
        # Chinese instruction: "First determine the user's intent; if it is a
        # drawing request, prepend <画图> to the output: {prompt}".
        self.query_template = "请先判断用户的意图,若为画图则在输出前加入<画图>:{}"

    def __call__(self, prompt):
        """Enhance ``prompt`` with a text-only model call.

        Returns:
            ``(True, enhanced_prompt)`` normally, or ``(False, "")`` when the
            model's reply is a refusal.
        """
        reply = eval_model(
            models=self.models,
            query=self.query_template.format(prompt),
            image_file=None,
        )
        enhanced, compliant = remove_prefix(reply)
        return (True, enhanced) if compliant else (False, "")
if __name__ == "__main__":
    # CLI smoke test: load the model, send one prompt (optionally with
    # image(s)), and print the raw reply.
    cli = argparse.ArgumentParser()
    cli.add_argument('--model_path', type=str, default='./ckpts/dialoggen')
    cli.add_argument('--prompt', type=str, default='画一只小猫')
    cli.add_argument('--image_file', type=str, default=None)  # e.g. 'images/demo1.jpeg'
    opts = cli.parse_args()

    dialoggen_models = init_dialoggen_model(opts.model_path)
    reply = eval_model(
        dialoggen_models,
        query=f"请先判断用户的意图,若为画图则在输出前加入<画图>:{opts.prompt}",
        image_file=opts.image_file,
    )
    print(reply)
|