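"""Interactive command-line chat with a PointLLM checkpoint over Objaverse point clouds.

The script loads the model once, then repeatedly prompts for an Objaverse
object_id, loads its 8192-point colored point cloud, and runs a multi-turn
conversation. Type 'exit' to switch objects and 'q' to quit.

Example invocation (a sketch; the defaults match the argparse setup below,
and the file name is whatever this script is saved as):

    python pointllm_chat.py --model_path RunsenXu/PointLLM_7B_v1.2 \
        --data_path data/objaverse_data --torch_dtype float16
"""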
import argparse
from transformers import AutoTokenizer
import torch
import os
from pointllm.conversation import conv_templates, SeparatorStyle
from pointllm.utils import disable_torch_init
from pointllm.model import *
from pointllm.model.utils import KeywordsStoppingCriteria
from pointllm.data import load_objaverse_point_cloud

def load_point_cloud(args):
    object_id = args.object_id
    print(f"[INFO] Loading point clouds using object_id: {object_id}")
    point_cloud = load_objaverse_point_cloud(args.data_path, object_id, pointnum=8192, use_color=True)

    return object_id, torch.from_numpy(point_cloud).unsqueeze_(0).to(torch.float32)

def init_model(args):
    # Model
    disable_torch_init()

    model_path = args.model_path
    print(f'[INFO] Model name: {model_path}')

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = PointLLMLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=False, use_cache=True, torch_dtype=args.torch_dtype).cuda()
    model.initialize_tokenizer_point_backbone_config_wo_embedding(tokenizer)
    model.eval()

    mm_use_point_start_end = getattr(model.config, "mm_use_point_start_end", False)
    # The point backbone config carries the special point-token ids and lengths
    point_backbone_config = model.get_model().point_backbone_config

    if mm_use_point_start_end:
        if "v1" in model_path.lower():
            conv_mode = "vicuna_v1_1"
        else:
            raise NotImplementedError

        conv = conv_templates[conv_mode].copy()

    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]

    return model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv

def start_conversation(args, model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv):
    point_token_len = point_backbone_config['point_token_len']
    default_point_patch_token = point_backbone_config['default_point_patch_token']
    default_point_start_token = point_backbone_config['default_point_start_token']
    default_point_end_token = point_backbone_config['default_point_end_token']

    # The while loop keeps running until the user decides to quit
    print("[INFO] Starting conversation... Enter 'q' to exit the program and enter 'exit' to exit the current conversation.")
    while True:
        print("-" * 80)

        # Prompt for object_id
        object_id = input("[INFO] Please enter the object_id or 'q' to quit: ")

        # Check if the user wants to quit
        if object_id.lower() == 'q':
            print("[INFO] Quitting...")
            break
        else:
            print(f"[INFO] Chatting with object_id: {object_id}.")

        # Update args with the new object_id
        args.object_id = object_id.strip()

        # Load the point cloud data
        try:
            id, point_clouds = load_point_cloud(args)
        except Exception as e:
            print(f"[ERROR] {e}")
            continue
        point_clouds = point_clouds.cuda().to(args.torch_dtype)

        # Reset the conversation template
        conv.reset()

        print("-" * 80)

        # Start a loop for multiple rounds of dialogue
        for i in range(100):
            qs = input(conv.roles[0] + ': ')
            if qs == 'exit':
                break

            # On the first round, prepend the point tokens so the point cloud is part of the prompt
            if i == 0:
                if mm_use_point_start_end:
                    qs = default_point_start_token + default_point_patch_token * point_token_len + default_point_end_token + '\n' + qs
                else:
                    qs = default_point_patch_token * point_token_len + '\n' + qs

            # Append the new message to the conversation history
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()
            inputs = tokenizer([prompt])

            input_ids = torch.as_tensor(inputs.input_ids).cuda()

            stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
            stop_str = keywords[0]

            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    point_clouds=point_clouds,
                    do_sample=True,
                    temperature=1.0,
                    top_k=50,
                    max_length=2048,
                    top_p=0.95,
                    stopping_criteria=[stopping_criteria])

            # Decode only the newly generated tokens, then strip the stop string
            input_token_len = input_ids.shape[1]
            n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
            if n_diff_input_output > 0:
                print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')

            outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
            outputs = outputs.strip()
            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
            outputs = outputs.strip()

            # Append the model's response to the conversation history
            conv.pop_last_none_message()
            conv.append_message(conv.roles[1], outputs)
            print(f'{conv.roles[1]}: {outputs}\n')

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="RunsenXu/PointLLM_7B_v1.2")
    parser.add_argument("--data_path", type=str, default="data/objaverse_data")
    parser.add_argument("--torch_dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"])

    args = parser.parse_args()

    # Map the command-line dtype string to the corresponding torch dtype
    dtype_mapping = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    args.torch_dtype = dtype_mapping[args.torch_dtype]

    model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv = init_model(args)
    start_conversation(args, model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv)