# vv / app.py
# Last commit: d883ba3 by paufeldman ("fix as_open3d deprecated")
import gradio as gr
import torch
from modelovae import Node, GRASSEncoder, GRASSDecoder
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from src.mesh_gen.modelador import GrafoCentros
import trimesh as tm
#import pymeshlab
import plotly.graph_objects as go
from torch.utils.data import Dataset, DataLoader
import os
from resamplear import *
import open3d as o3d
import pymeshlab as pm
from vedo import *
def count_fn(f):
    """Decorate *f* so every call increments the wrapper's ``count`` attribute.

    The counter starts at 0 and is incremented before *f* runs, so a call
    that raises is still counted. Callers may reset it by assigning to
    ``.count``. ``functools.wraps`` keeps the wrapped function's name and
    docstring visible on the wrapper.
    """
    import functools

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        wrapper.count += 1
        return f(*args, **kwargs)

    wrapper.count = 0
    return wrapper
@count_fn
def createNode(data, radius, left=None, right=None):
    """Build a ``Node`` holding *data* and *radius*, with optional children.

    Wrapped by ``count_fn``, so each call bumps ``createNode.count``.
    """
    node = Node(data, radius, left, right)
    return node
def decode_testing(v, max, decoder, mult, min):
    """Decode latent code *v* into a binary tree of nodes.

    Every code is classified with ``decoder.nodeClassifier``; the arg-max
    label chooses what to build:

    * 1 -- one child, decoded with ``decoder.internalDecoder`` into ``right``;
    * 2 -- two children via ``decoder.bifurcationDecoder``;
    * 0 -- a leaf via ``decoder.featureDecoder``, but only when at least two
      bifurcations lie above it; closer to the root a 0 is expanded as a
      bifurcation instead.

    ``createNode.count`` is reset to 0 and acts as a node budget: labels 1
    and 2 are honoured only while it is ``<= max``; otherwise (and for any
    other label) that child is ``None``. *mult* and *min* are part of the
    signature but are not used by the decoding logic.

    Returns the root node.
    """
    def build(code, nbif):
        label = torch.max(decoder.nodeClassifier(code), 1)[1].data
        under_budget = createNode.count <= max

        def fork():
            # Right subtree first, then left: both draw on the shared node
            # budget, so this order decides which side gets truncated.
            left_code, right_code, rad = decoder.bifurcationDecoder(code)
            node = createNode(1, rad)
            node.right = build(right_code, nbif + 1)
            node.left = build(left_code, nbif + 1)
            return node

        if label == 1 and under_budget:
            child_code, rad = decoder.internalDecoder(code)
            node = createNode(1, rad)
            node.right = build(child_code, nbif)
            return node
        if label == 2 and under_budget:
            return fork()
        if label == 0:
            if nbif >= 2:
                return createNode(1, decoder.featureDecoder(code))
            # Too shallow for a leaf: force another bifurcation.
            return fork()
        return None

    createNode.count = 0
    return build(v, 0)
def numerar_nodos(root, count):
    """Number the tree's nodes in in-order sequence.

    Each visited node's ``data`` is set to ``len(count)`` and a ``1`` is then
    appended to *count*, so numbering continues from the list's current
    length. Mutates the nodes and *count*; returns ``None``.
    """
    pending = []
    node = root
    while pending or node is not None:
        # Walk as far left as possible, remembering the path back up.
        while node is not None:
            pending.append(node)
            node = node.left
        node = pending.pop()
        node.data = len(count)
        count.append(1)
        node = node.right
def traverse_xy(root, tree):
    """Append ``[radius[0][0], radius[0][1]]`` for every node, in in-order.

    Each node's radius is converted to a NumPy array, printed, and its first
    two components appended to *tree*, which is returned.
    """
    if root is None:
        return tree
    traverse_xy(root.left, tree)
    values = root.radius.cpu().detach().numpy()
    print(values)
    tree.append([values[0][0], values[0][1]])
    traverse_xy(root.right, tree)
    return tree
def tr(root):
    """Divide component ``[0][3]`` of every node's radius by 20, clamping negatives to 0.

    Walks the tree in order and mutates each ``radius`` in place; returns
    ``None``.
    """
    if root is None:
        return
    tr(root.left)
    scaled = root.radius[0][3] / 20
    root.radius[0][3] = 0 if scaled < 0 else scaled
    tr(root.right)
def tr2(root):
    """Divide component ``[3]`` of every node's radius by 10, clamping negatives to 0.

    Like ``tr`` but for radii indexed directly as ``radius[3]``. Walks the
    tree in order, mutating in place; returns ``None``.
    """
    if root is None:
        return
    tr2(root.left)
    scaled = root.radius[3] / 10
    root.radius[3] = 0 if scaled < 0 else scaled
    tr2(root.right)
# Target device for tensors built by deserialize(): the first CUDA device
# when use_gpu is set and CUDA is available, otherwise the CPU.
use_gpu = True
if use_gpu and torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
def traversefeatures(root, features):
    """Collect component ``[0][3]`` of every node's radius, in in-order.

    Values come from ``radius.tolist()`` (plain Python numbers) and are
    appended to *features*, which is also returned.
    """
    stack = []
    node = root
    while stack or node is not None:
        if node is not None:
            stack.append(node)
            node = node.left
            continue
        node = stack.pop()
        features.append(node.radius.tolist()[0][3])
        node = node.right
    return features
def my_collate(batch):
    """Collate function for the DataLoader that leaves the batch untouched.

    The list of ``(tree, name)`` samples is handed back as-is rather than
    being stacked into tensors.
    """
    collated = batch
    return collated
def deserialize(data):
    """Rebuild a tree from its ``;``-separated serialized form.

    Each token is either ``#`` (an empty child) or ``<data>_<radius>``, where
    ``<data>`` is an int and ``<radius>`` is a comma-separated list of floats
    wrapped in ``[``/``]``. Tokens are consumed from the end of the string:
    after a node token come (reading backwards) its right subtree, then its
    left subtree. Each radius becomes a float tensor on ``device``; nodes are
    built with ``createNode``.

    Returns the root node, or ``None`` for empty input.
    """
    if not data:
        return None
    tokens = data.split(';')

    def _parse_radius(text):
        # Drop the list brackets, then parse every comma-separated value, so
        # radii with any number of components are accepted.
        cleaned = text.replace('[', '').replace(']', '')
        values = [float(value) for value in cleaned.split(',')]
        return torch.tensor(values, device=device)

    def post_order(nodes):
        if nodes[-1] == '#':
            nodes.pop()
            return None
        node = nodes.pop().split('_')
        root = createNode(int(node[0]), _parse_radius(node[1]))
        root.right = post_order(nodes)
        root.left = post_order(nodes)
        return root

    return post_order(tokens)
def read_tree(filename, dir):
    """Return the text content of *filename* inside directory *dir*.

    A relative *dir* is resolved against the current working directory; an
    absolute one is used as given. The file is decoded as UTF-8.
    """
    path = os.path.join('.', dir, filename)
    with open(path, "r", encoding="utf-8") as f:
        return f.read()
class tDataset(Dataset):
    """Map-style dataset of trees read from text files in *dir*.

    Every file named in *l* is read eagerly with ``read_tree`` and parsed
    with ``deserialize`` when the dataset is constructed. Items are
    ``(tree, filename)`` pairs. *transform* is stored but never applied.
    """

    def __init__(self, l, dir, transform=None):
        self.names = l
        self.transform = transform
        # Raw serialized strings, one per file, in the same order as names.
        self.data = [read_tree(name, dir) for name in self.names]
        self.trees = [deserialize(text) for text in self.data]

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        return self.trees[idx], self.names[idx]
# Number of trees per DataLoader batch.
batch_size = 1
def prepararAristas( grafo ):
    """Mark every edge of *grafo* as not yet processed (``procesada=False``).

    Mutates the graph in place and returns ``None``.
    """
    # A scalar value with an attribute name sets it on all edges in one call;
    # unlike a dict keyed by (u, v) it also covers multigraph edge keys.
    nx.set_edge_attributes(grafo, False, 'procesada')
# Load and parse every serialized tree in ./nuevostrees/ at import time and
# wrap them in a shuffling loader that yields lists of (tree, name) samples.
t_list = os.listdir("./nuevostrees/")
dataset = tDataset(t_list, "./nuevostrees/")
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=my_collate,
)
def _radius_ratio(max_radius, min_radius):
    """Return ``max_radius / min_radius``, or ``inf`` when *min_radius* is 0.

    A zero minimum would otherwise raise ``ZeroDivisionError`` and abort the
    request; an infinite ratio makes the sample be rejected and retried.
    """
    if min_radius == 0:
        return float("inf")
    return max_radius / min_radius


def _subdivided_mesh(graphOfCenters):
    """Tile *graphOfCenters* into a loop-subdivided mesh and return its arrays.

    Writes the coarse mesh to ``sinsub.obj`` and the subdivided one to
    ``output.obj``; PyMeshLab reads the latter back to test for
    self-intersections.

    Returns:
        ``(vertices, triangles)`` NumPy arrays of the subdivided mesh.

    Raises:
        Exception: ``"autointerseccion"`` if any face self-intersects.
            Errors from tiling or mesh I/O propagate unchanged.
    """
    graphOfCenters.tile()
    mesh = tm.Trimesh(graphOfCenters.getVertices(), graphOfCenters.getCaras())
    # Build the Open3D mesh explicitly instead of via ``mesh.as_open3d``.
    coarse = o3d.geometry.TriangleMesh()
    coarse.vertices = o3d.utility.Vector3dVector(mesh.vertices)
    coarse.triangles = o3d.utility.Vector3iVector(mesh.faces)
    o3d.io.write_triangle_mesh("sinsub.obj", coarse)
    fine = coarse.subdivide_loop(3)
    o3d.io.write_triangle_mesh("output.obj", fine)
    ms = pm.MeshSet()
    ms.load_new_mesh('output.obj')
    ms.compute_selection_by_self_intersections_per_face()
    if ms.current_mesh().selected_face_number() > 0:
        raise Exception("autointerseccion")
    return np.asarray(fine.vertices), np.asarray(fine.triangles)


def _mesh_figure(v_matrix, f_matrix):
    """Return a red Plotly ``Mesh3d`` of the mesh, centred and uniformly scaled.

    Each axis is shifted to zero mean and all are divided by the largest
    per-axis extent; the three axes share one range and are hidden.
    """
    extent = (v_matrix.max(axis=0) - v_matrix.min(axis=0)).max()
    centred = (v_matrix - v_matrix.mean(axis=0)) / extent
    lo = centred.min()
    hi = centred.max()
    fig = go.Figure(go.Mesh3d(
        x=centred[:, 0], y=centred[:, 1], z=centred[:, 2],
        i=f_matrix[:, 0], j=f_matrix[:, 1], k=f_matrix[:, 2],
        color='red'))
    fig.update_layout(
        scene=dict(
            xaxis_range=[lo, hi],
            yaxis_range=[lo, hi],
            zaxis_range=[lo, hi],
            aspectratio=dict(x=1, y=1, z=1),
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
        ))
    return fig


def predict():
    """Sample a random tree from the decoder and return it as a Plotly figure.

    Samples are drawn until one passes both checks:

    * the ratio between the largest and smallest node radius (component
      ``[0][3]``) is below 10 in absolute value, and
    * the tiled, subdivided mesh has no self-intersecting faces.

    There is no attempt limit: if no valid sample is ever produced, this
    loops forever.
    """
    while True:
        G = nx.Graph()
        # Inference only: no autograd graph is needed for the decoded tree.
        with torch.no_grad():
            z = torch.randn(1, latent_size)
            tree = decode_testing(z, 30, Grassdecoder, mult, 1)
        numerar_nodos(tree, [])
        tree.toGraph(G, 0, True, flag=0)
        radii = traversefeatures(tree, [])
        max_radius = max(radii)
        min_radius = min(radii)
        # Resampling (resamplear) is disabled; the raw graph is meshed.
        prepararAristas(G)
        graphOfCenters = GrafoCentros(G)
        ratio = _radius_ratio(max_radius, min_radius)
        # Written as "not < 10" so a NaN ratio is rejected rather than meshed.
        if not abs(ratio) < 10:
            print("ratio", ratio)
            continue
        try:
            v_matrix, f_matrix = _subdivided_mesh(graphOfCenters)
        except Exception as e:
            # Best-effort: any tiling/meshing failure just triggers a resample.
            print("No se pudo generar")
            print(str(e))
            continue
        return _mesh_figure(v_matrix, f_matrix)
# Decoder set-up. These sizes presumably have to match the checkpoint's
# parameter shapes, since load_state_dict runs with its default strict
# checking -- confirm before changing them.
a = [1., 1., 1.]
mult = torch.Tensor(a)
latent_size = 32
Grassdecoder = GRASSDecoder(latent_size=latent_size, hidden_size=256, mult=mult)
Grassdecoder.eval()
# The checkpoint is always mapped onto the CPU.
checkpoint = torch.load("64bueno.pth", map_location=torch.device('cpu'))
Grassdecoder.load_state_dict(checkpoint['decoder_state_dict'])

# No inputs: predict() is called without arguments and its Plotly figure
# is shown in a gr.Plot output.
gr.Interface(predict, inputs=None, outputs=gr.Plot()).launch()