File size: 4,484 Bytes
5ea1b6f
5eb7f8a
 
5ea1b6f
5eb7f8a
ede25fc
2e0c5aa
5ea1b6f
2e121c3
2e0c5aa
 
ede25fc
5ea1b6f
 
 
ede25fc
2e0c5aa
5eb7f8a
 
ede25fc
fa53b56
ede25fc
fa53b56
 
 
ede25fc
 
878eb5c
fa53b56
878eb5c
ede25fc
 
525ce44
3e4a0db
525ce44
5ea1b6f
 
5eb7f8a
 
 
 
2e0c5aa
 
 
 
 
 
 
 
 
 
 
d9403e1
 
3aabd0b
d9403e1
 
 
 
27ec183
e7895e8
 
 
 
65555e4
2e0c5aa
 
 
 
65555e4
 
 
2e0c5aa
 
65555e4
 
2e0c5aa
f369ed3
2495238
3aabd0b
65555e4
2e0c5aa
65555e4
2e0c5aa
1036858
5ea1b6f
 
764e17a
c257e9e
764e17a
878eb5c
 
 
 
525ce44
878eb5c
 
 
525ce44
878eb5c
 
 
 
 
 
 
 
 
5ea1b6f
2e0c5aa
5ea1b6f
 
 
 
 
878eb5c
 
 
 
 
 
 
 
 
5ea1b6f
 
ede25fc
5ea1b6f
ede25fc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# PyOpenGL chooses its platform backend the first time it is imported, and
# pyrender imports it, so PYOPENGL_PLATFORM must be set before
# `import pyrender` (switch to "osmesa" or "egl" here).
import os
os.environ["PYOPENGL_PLATFORM"] = "egl"

import pickle

import gradio as gr
import matplotlib as mpl
import matplotlib.cm as cm
import numpy as np
import pyrender
from datasets import load_dataset
from plaid.containers.sample import Sample
from trimesh import Trimesh


hf_dataset = load_dataset("PLAID-datasets/AirfRANS_remeshed", split="all_samples")

# Upper bound on the samples exposed by the slider; never exceed what the
# dataset actually holds, otherwise high slider ids would raise IndexError.
nb_samples = min(1000, len(hf_dataset))

# Nodal fields the user can choose to colour the mesh with.
field_names_train = ["Ux", "Uy", "p", "nut", "implicit_distance"]


_HEADER_ = '''
<h2><b>Visualization demo of <a href='https://huggingface.co/datasets/PLAID-datasets/AirfRANS_remeshed' target='_blank'><b>AirfRANS_remeshed dataset</b></a></b></h2>
'''


def round_num(num) -> str:
    """Return *num* rounded to 3 significant digits, rendered as a float string.

    The value is first formatted with ``%.3g``, then parsed back to a float so
    that exponent notation such as ``1.23e+03`` is shown as ``1230.0``.
    """
    three_sig_digits = "%.3g" % num
    return str(float(three_sig_digits))

def _field_vertex_colors(field):
    """Map nodal *field* values to per-vertex RGB colours.

    A field with non-zero norm is colour-mapped with the ``seismic`` colormap
    normalised over its own [min, max] range; an all-zero field gets one
    constant colour for every vertex.
    """
    if np.linalg.norm(field) > 0:
        norm = mpl.colors.Normalize(vmin=np.min(field), vmax=np.max(field))
        mappable = cm.ScalarMappable(norm=norm, cmap=cm.seismic)
        # to_rgba yields RGBA; only the RGB channels are used for the mesh.
        return mappable.to_rgba(field)[:, :3]
    # Constant colour historically used by this demo for zero fields.
    return np.tile([0.2298057, 0.01555616, 0.15023281], (field.shape[0], 1))


def _render_field(nodes, triangles, vertex_colors):
    """Render the vertex-coloured triangle mesh offscreen at 1024x1024.

    Returns the colour buffer produced by ``pyrender.OffscreenRenderer.render``.
    """
    tri_mesh = Trimesh(vertices=nodes, faces=triangles)
    tri_mesh.visual.vertex_colors = vertex_colors
    mesh = pyrender.Mesh.from_trimesh(tri_mesh, smooth=False)

    scene = pyrender.Scene(ambient_light=[.1, .1, .3], bg_color=[0, 0, 0])
    camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0)
    light = pyrender.DirectionalLight(color=[1, 1, 1], intensity=1000.)

    scene.add(mesh, pose=np.eye(4))
    scene.add(light, pose=np.eye(4))
    # Identity rotation, camera translated to (x=1, y=0, z=6).
    scene.add(camera, pose=[[1, 0, 0, 1],
                            [0, 1, 0, 0],
                            [0, 0, 1, 6],
                            [0, 0, 0, 1]])

    renderer = pyrender.OffscreenRenderer(1024, 1024)
    try:
        color, _ = renderer.render(scene)
    finally:
        # Free the renderer's OpenGL resources; without this every UI event
        # leaks a context/framebuffer.
        renderer.delete()
    return color


def _describe_sample(sample_id_str, plaid_sample, nb_nodes):
    """Build the text summary shown next to the rendered sample."""
    description = hf_dataset.description
    parts = [f"Training sample {sample_id_str}\n", str(plaid_sample) + "\n"]

    for title, key in (("input scalars", "in_scalars_names"),
                       ("output scalars", "out_scalars_names")):
        names = description[key]
        if len(names) > 0:
            parts.append(f"\n{title}:\n")
            parts.extend(
                f"- {name}: {round_num(plaid_sample.get_scalar(name))}\n"
                for name in names
            )

    parts.append(f"\n\nMesh number of nodes: {nb_nodes}\n")

    for title, key in (("input fields", "in_fields_names"),
                       ("output fields", "out_fields_names")):
        names = description[key]
        if len(names) > 0:
            parts.append(f"\n{title}:\n")
            parts.extend(f"- {name}\n" for name in names)

    return "".join(parts)


def sample_info(sample_id_str, fieldn):
    """Load one dataset sample, render field *fieldn* on its mesh, describe it.

    Args:
        sample_id_str: Index of the sample in ``hf_dataset``; any value that
            ``int()`` accepts (e.g. the Gradio slider value).
        fieldn: Name of the nodal field used to colour the mesh.

    Returns:
        ``(info_text, color)``: the textual summary of the sample and the
        colour buffer of the offscreen render.
    """
    raw_sample = hf_dataset[int(sample_id_str)]["sample"]
    # SECURITY(review): pickle.loads can execute arbitrary code embedded in the
    # payload; this is only acceptable while the dataset source is trusted.
    plaid_sample = Sample.model_validate(pickle.loads(raw_sample))

    nodes = plaid_sample.get_nodes()
    field = plaid_sample.get_field(fieldn)
    if nodes.shape[1] == 2:
        # Pad 2D coordinates with z = 0 so the mesh gets 3D vertices.
        nodes = np.column_stack([nodes, np.zeros(nodes.shape[0])])

    triangles = plaid_sample.get_elements()['TRI_3']

    color = _render_field(nodes, triangles, _field_vertex_colors(field))
    return _describe_sample(sample_id_str, plaid_sample, nodes.shape[0]), color


if __name__ == "__main__":

    # Highest selectable sample index (slider is inclusive on both ends).
    last_id = nb_samples - 1

    with gr.Blocks() as demo:
        gr.Markdown(_HEADER_)
        with gr.Row(variant="panel"):
            with gr.Column():
                sample_slider = gr.Slider(
                    0,
                    last_id,
                    value=0,
                    label="Training sample id",
                    info=f"Choose between 0 and {last_id}",
                )
                info_box = gr.Text(label="Training sample info")
            with gr.Column():
                field_dropdown = gr.Dropdown(
                    field_names_train,
                    value=field_names_train[0],
                    label="Field name",
                )
                image_box = gr.Image(label="Training sample visualization")

        # Any user edit of either control re-runs sample_info on both values.
        controls = [sample_slider, field_dropdown]
        for control in controls:
            control.input(sample_info, controls, [info_box, image_box])

    demo.launch()