import argparse
import os

import torch
from transformers import AutoTokenizer

from pointllm.conversation import conv_templates, SeparatorStyle
from pointllm.utils import disable_torch_init
from pointllm.model import *
from pointllm.model.utils import KeywordsStoppingCriteria
from pointllm.data import load_objaverse_point_cloud


def load_point_cloud(args):
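    # Load one Objaverse point cloud; with use_color=True this is an (N, 6)
    # array of XYZ + RGB values, sampled to 8192 points.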
    object_id = args.object_id
    print(f"[INFO] Loading point cloud with object_id: {object_id}")
    point_cloud = load_objaverse_point_cloud(args.data_path, object_id, pointnum=8192, use_color=True)

    # Add a batch dimension so the returned tensor is (1, N, 6).
    return object_id, torch.from_numpy(point_cloud).unsqueeze_(0).to(torch.float32)


def init_model(args):
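    # Load the tokenizer and PointLLM weights, register the point-token config
    # on the tokenizer, and select the conversation template and stop keywords.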
    disable_torch_init()

    model_path = args.model_path
    print(f'[INFO] Model name: {model_path}')

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = PointLLMLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=False, use_cache=True, torch_dtype=args.torch_dtype).cuda()
    model.initialize_tokenizer_point_backbone_config_wo_embedding(tokenizer)

    model.eval()
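
    # The point backbone config describes how point tokens are laid out in the
    # prompt (patch token, optional start/end tokens, token count).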
    mm_use_point_start_end = getattr(model.config, "mm_use_point_start_end", False)

    point_backbone_config = model.get_model().point_backbone_config

    # Choose the conversation template unconditionally so conv_mode is always
    # defined; only v1-style checkpoints are supported.
    if "v1" in model_path.lower():
        conv_mode = "vicuna_v1_1"
    else:
        raise NotImplementedError

    conv = conv_templates[conv_mode].copy()

    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]

    return model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv


def start_conversation(args, model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv):
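    # Interactive loop: the outer loop selects an object, the inner loop
    # chats about it until the user types 'exit'.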
    point_token_len = point_backbone_config['point_token_len']
    default_point_patch_token = point_backbone_config['default_point_patch_token']
    default_point_start_token = point_backbone_config['default_point_start_token']
    default_point_end_token = point_backbone_config['default_point_end_token']

    print("[INFO] Starting conversation... Enter 'q' to quit the program, or 'exit' to end the current conversation.")
    while True:
        print("-" * 80)

        object_id = input("[INFO] Please enter the object_id or 'q' to quit: ")

        if object_id.lower() == 'q':
            print("[INFO] Quitting...")
            break
        else:
            print(f"[INFO] Chatting with object_id: {object_id}.")

        args.object_id = object_id.strip()

        try:
            object_id, point_clouds = load_point_cloud(args)
        except Exception as e:
            print(f"[ERROR] {e}")
            continue
        point_clouds = point_clouds.cuda().to(args.torch_dtype)
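
        # Reset conversation history before chatting about the new object.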
        conv.reset()

        print("-" * 80)

        # Cap each conversation at 100 turns.
        for i in range(100):
            qs = input(conv.roles[0] + ': ')
            if qs == 'exit':
                break
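
            # On the first turn, prepend the point-cloud placeholder tokens;
            # the model replaces them with point-backbone embeddings.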
            if i == 0:
                if mm_use_point_start_end:
                    qs = default_point_start_token + default_point_patch_token * point_token_len + default_point_end_token + '\n' + qs
                else:
                    qs = default_point_patch_token * point_token_len + '\n' + qs
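
            # Append the turn to the conversation template, then tokenize the full prompt.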
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()
            inputs = tokenizer([prompt])

            input_ids = torch.as_tensor(inputs.input_ids).cuda()

            stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
            stop_str = keywords[0]
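
            # Sample a reply; the point_clouds tensor is consumed by the
            # model's point backbone during generation.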
            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    point_clouds=point_clouds,
                    do_sample=True,
                    temperature=1.0,
                    top_k=50,
                    max_length=2048,
                    top_p=0.95,
                    stopping_criteria=[stopping_criteria])
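
            # Decode only the newly generated tokens and strip the stop string.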
            input_token_len = input_ids.shape[1]
            n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
            if n_diff_input_output > 0:
                print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
            outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
            outputs = outputs.strip()
            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
            outputs = outputs.strip()
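
            # Replace the placeholder assistant message with the actual reply.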
            conv.pop_last_none_message()
            conv.append_message(conv.roles[1], outputs)
            print(f'{conv.roles[1]}: {outputs}\n')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="RunsenXu/PointLLM_7B_v1.2")
    parser.add_argument("--data_path", type=str, default="data/objaverse_data")
    parser.add_argument("--torch_dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"])

    args = parser.parse_args()
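
    # Map the dtype string from the CLI onto the corresponding torch dtype.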
    dtype_mapping = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }

    args.torch_dtype = dtype_mapping[args.torch_dtype]

    model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv = init_model(args)

    start_conversation(args, model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv)
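
# Example invocation (the script filename and data path are illustrative):
#   python chat.py --model_path RunsenXu/PointLLM_7B_v1.2 --data_path data/objaverse_data --torch_dtype float16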