File size: 8,354 Bytes
676b3ba
d33f11b
676b3ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
import argparse
# from transformers import AutoTokenizer
import torch
import os
import numpy as np

import os

# Additional import for gradio
import gradio as gr
import open3d as o3d
import plotly.graph_objects as go
import time

import logging


def farthest_point_sample(point, npoint, rng=None):
    """Downsample a point cloud with iterative farthest point sampling (FPS).

    Distances are measured on the first three columns (xyz) only; any extra
    columns (e.g. colors) travel with their rows.

    Input:
        point: pointcloud data, [N, D] with D >= 3
        npoint: number of samples to keep (if larger than N, rows repeat)
        rng: optional ``np.random.Generator`` used to choose the first sample;
            when None the global ``np.random`` state is used, as before.
    Return:
        point: the sampled rows of ``point``, [npoint, D]
    """
    num_points = point.shape[0]
    xyz = point[:, :3]
    # Integer indices of the chosen rows (previously kept in a float array).
    centroids = np.zeros(npoint, dtype=np.int64)
    # Squared distance from each point to its nearest already-chosen sample.
    distance = np.full(num_points, 1e10)
    if rng is None:
        farthest = np.random.randint(0, num_points)
    else:
        farthest = int(rng.integers(0, num_points))
    for i in range(npoint):
        centroids[i] = farthest
        dist = np.sum((xyz - xyz[farthest]) ** 2, axis=-1)
        closer = dist < distance
        distance[closer] = dist[closer]
        # Next sample: the point farthest from every sample chosen so far.
        farthest = int(np.argmax(distance))
    return point[centroids]

def pc_norm(pc):
    """Center xyz on its centroid and scale it into the unit sphere.

    pc: NxC array whose first three columns are xyz; the remaining columns
    (e.g. colors) are passed through unchanged. Returns an NxC array.

    If every point coincides (maximum radius 0), the centered xyz (all zeros)
    is returned rather than dividing by zero. An empty input is returned as a
    copy instead of failing in ``np.max``.
    """
    if pc.shape[0] == 0:
        return pc.copy()
    xyz = pc[:, :3]
    other_feature = pc[:, 3:]

    xyz = xyz - np.mean(xyz, axis=0)
    radius = np.max(np.sqrt(np.sum(xyz ** 2, axis=1)))
    if radius > 0:
        xyz = xyz / radius

    return np.concatenate((xyz, other_feature), axis=1)

def change_input_method(input_method):
    """Return Gradio visibility updates for the two input widgets.

    Args:
        input_method: either 'File' or 'Object ID'.

    Returns:
        ``[file_widget_update, object_id_widget_update]`` showing only the
        widget that matches ``input_method``.

    Raises:
        ValueError: for any other value (previously this surfaced as an
            UnboundLocalError on ``result``).
    """
    # (file widget visible, object-id widget visible) per input method.
    visibility = {
        'File': (True, False),
        'Object ID': (False, True),
    }
    if input_method not in visibility:
        raise ValueError(f"Unknown input method: {input_method!r}")
    show_file, show_object_id = visibility[input_method]
    return [gr.update(visible=show_file),
            gr.update(visible=show_object_id)]


def _log(message):
    """Print *message* and also record it in the log file at WARNING level."""
    print(message)
    logging.warning(message)


def _load_point_cloud(file):
    """Read xyz coordinates and optional colors from a ``.ply`` or ``.npy`` file.

    Args:
        file: path of the uploaded file; the format is chosen from its
            (case-insensitive) extension.

    Returns:
        ``(points, colors)``: ``points`` is an ``[N, 3]`` array; ``colors`` is
        an ``[N, 3]`` array, or ``None`` when the file carries no colors.

    Raises:
        ValueError: unsupported extension, wrongly shaped ``.npy`` array, or an
            empty point cloud.
    """
    ext = os.path.splitext(file)[1].lower()
    if ext == '.ply':
        pcd = o3d.io.read_point_cloud(file)
        points = np.asarray(pcd.points)  # xyz
        colors = np.asarray(pcd.colors)  # rgb; an empty array when the file has none
        if colors.size == 0:
            colors = None
    elif ext == '.npy':
        # np.load keeps its default allow_pickle=False: uploads are untrusted.
        data = np.load(file)
        if data.ndim != 2 or data.shape[1] < 3:
            raise ValueError("Input array has the wrong shape. Expected: [N, 3]. Got: {}.".format(data.shape))
        points = data[:, :3]
        colors = data[:, 3:6] if data.shape[1] >= 6 else None
    else:
        raise ValueError("Not supported data format.")
    if points.shape[0] == 0:
        raise ValueError("Point cloud contains no points.")
    return points, colors


def _colors_to_rgb255(colors, num_points):
    """Convert optional per-point colors to an ``[N, 3]`` int array in 0-255.

    Colors whose maximum is <= 1 are treated as 0-1 floats and scaled by 255.
    Anything else is taken to be on the 0-255 scale already. Values are clipped
    to [0, 255] so out-of-range data cannot produce invalid ``rgb()`` strings
    (previously a maximum above 255 crashed with an unbound variable).
    ``None`` yields all-black colors.
    """
    if colors is None:
        return np.zeros((num_points, 3), dtype=int)
    colors = np.asarray(colors, dtype=np.float64)
    if np.max(colors) <= 1:
        colors = colors * 255
    return np.clip(colors, 0, 255).astype(int)


def _build_figure(points, color_data):
    """Return a Plotly 3D scatter of ``points`` colored by ``color_data`` (0-255 ints)."""
    # Plotly marker colors as 'rgb(r,g,b)' strings, one per point.
    color_strings = ['rgb({},{},{})'.format(r, g, b) for r, g, b in color_data]
    return go.Figure(
        data=[
            go.Scatter3d(
                x=points[:, 0], y=points[:, 1], z=points[:, 2],
                mode='markers',
                marker=dict(
                    size=1.2,
                    color=color_strings,
                ),
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
            ),
            paper_bgcolor='rgb(255,255,255)',  # white background
        ),
    )


def _to_model_input(points, colors, max_points):
    """Build the ``[1, M, 6]`` float32 model-input tensor.

    xyz and 0-1 colors are concatenated. The cloud is downsampled to
    ``max_points`` with farthest point sampling when it is larger, then
    normalized with ``pc_norm``. ``M`` is ``min(N, max_points)``.
    """
    cloud = np.concatenate((points, colors), axis=1)
    if cloud.shape[0] > max_points:
        cloud = farthest_point_sample(cloud, max_points)
    cloud = pc_norm(cloud)
    return torch.from_numpy(cloud).unsqueeze_(0).to(torch.float32)


def start_conversation(args):
    """Build the point-cloud visualization Gradio app and serve it.

    Args:
        args: parsed CLI namespace. ``args.port`` is the server port.
            ``args.pointnum``, when present, caps the number of points kept
            for the model input (default 8192).

    ``demo.launch`` blocks in a script (Gradio's default
    ``prevent_thread_lock=False``). When the server shuts down this function
    returns; it no longer relaunches in an endless loop.
    """
    max_points = getattr(args, "pointnum", 8192)

    print("[INFO] Starting conversation...")
    logging.warning("Starting conversation...")
    _log("-" * 80)

    def confirm_point_cloud(point_cloud_input, answer_time):
        """Gradio callback for the "Confirm Point Cloud" button.

        Returns ``(figure, answer_time, point_clouds)``, matching the three
        output components. On failure it returns ``(None, answer_time, None)``
        so the UI keeps working; previously the error path returned 4 values.
        """
        _log("%" * 80)

        if point_cloud_input is None:
            # The button was clicked before any file was uploaded.
            _log("[ERROR] No point cloud file uploaded.")
            return None, answer_time, None

        # Depending on the Gradio version / gr.File ``type``, the value is either
        # an object exposing ``.name`` or the file path itself.
        file = getattr(point_cloud_input, "name", point_cloud_input)
        _log(f"Uploading file: {file}.")
        _log("%" * 80)

        try:
            points, colors = _load_point_cloud(file)
        except Exception as e:  # UI boundary: report and keep serving
            _log(f"[ERROR] {e}")
            return None, answer_time, None

        # Filename convention: "no_color" forces colors to be ignored.
        if "no_color" in file:
            colors = None

        color_data = _colors_to_rgb255(colors, points.shape[0])
        fig = _build_figure(points, color_data)
        # The model input uses 0-1 colors.
        point_clouds = _to_model_input(points, color_data.astype(np.float32) / 255, max_points)

        # answer_time is reset to 0 for a freshly confirmed point cloud.
        return fig, 0, point_clouds

    with gr.Blocks() as demo:
        answer_time = gr.State(value=0)
        point_clouds = gr.State(value=None)
        gr.Markdown(
            """
            # PointCloud Visualization 👀
            """
        )
        with gr.Row():
            with gr.Column():
                point_cloud_input = gr.File(visible=True, label="Upload Point Cloud File (PLY, NPY)")
                output = gr.Plot()
                btn = gr.Button(value="Confirm Point Cloud")

            btn.click(
                confirm_point_cloud,
                inputs=[point_cloud_input, answer_time],
                outputs=[output, answer_time, point_clouds],
            )

    demo.queue()
    # share=True exposes the app through a public Gradio share link.
    demo.launch(server_port=args.port, share=True)
    
if __name__ == "__main__":
    # ! To release this demo in public, make sure to start in a place where no important data is stored.
    # ! Please check 1. the launch dir 2. the tmp dir (GRADIO_TEMP_DIR)
    # ! refer to https://www.gradio.app/guides/sharing-your-app#security-and-file-access
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default="RunsenXu/PointLLM_7B_v1.2")

    parser.add_argument("--data_path", type=str, default="data/objaverse_data", required=False)
    parser.add_argument("--pointnum", type=int, default=8192)

    parser.add_argument("--log_file", type=str, default="serving_workdirs/serving_log.txt")
    parser.add_argument("--tmp_dir", type=str, default="serving_workdirs/tmp")

    # For gradio
    parser.add_argument("--port", type=int, default=7810)

    args = parser.parse_args()

    # * make serving dirs; a bare log filename has no directory part and
    # * os.makedirs("") would raise, so only create it when present.
    log_dir = os.path.dirname(args.log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    os.makedirs(args.tmp_dir, exist_ok=True)

    # * add the current time to the log name, just before the extension
    # * (works for any extension, not only ".txt")
    log_root, log_ext = os.path.splitext(args.log_file)
    timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
    args.log_file = f"{log_root}_{timestamp}{log_ext}"

    logging.basicConfig(
        filename=args.log_file,
        level=logging.WARNING,  # * default gradio is info, so use warning
        format='%(asctime)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    logging.warning("-----New Run-----")
    logging.warning(f"args: {args}")

    print("-----New Run-----")
    print(f"[INFO] Args: {args}")

    # * set env variable GRADIO_TEMP_DIR to args.tmp_dir before the app is built
    os.environ["GRADIO_TEMP_DIR"] = args.tmp_dir

    # model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv = init_model(args)
    start_conversation(args)