import functools
import gradio as gr
import torch
from modelovae import Node, GRASSEncoder, GRASSDecoder
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from src.mesh_gen.modelador import GrafoCentros
import trimesh as tm
import plotly.graph_objects as go
from torch.utils.data import Dataset, DataLoader
import os
from resamplear import *
import open3d as o3d
import pymeshlab as pm
from vedo import *


def count_fn(f):
    """Wrap *f* so that every call increments ``wrapper.count``.

    The counter starts at 0 and callers may reset it by assigning to
    ``.count``; ``decode_testing`` does so to bound the number of nodes.
    """
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        wrapper.count += 1
        return f(*args, **kwargs)

    wrapper.count = 0
    return wrapper


@count_fn
def createNode(data, radius, left=None, right=None):
    """Create a tree ``Node``; each call increments ``createNode.count``."""
    return Node(data, radius, left, right)


def decode_testing(v, max, decoder, mult, min):
    """Decode the latent code *v* into a binary tree of ``Node`` objects.

    Args:
        v: latent code fed to the decoder (``predict`` passes a tensor of
            shape ``(1, latent_size)``).
        max: node budget. Once ``createNode.count`` exceeds it, nodes the
            classifier labels 1 (internal) or 2 (bifurcation) are not
            expanded and become ``None``.
        decoder: object exposing ``nodeClassifier``, ``internalDecoder``,
            ``bifurcationDecoder`` and ``featureDecoder``.
        mult: unused; kept for backward compatibility.
        min: unused; the minimum number of bifurcations on a root-to-leaf
            path is hard-coded to 2 below.

    Returns:
        The root ``Node`` (or ``None``).

    Side effects:
        Resets ``createNode.count`` to 0 before decoding.
    """
    node_budget = max  # the parameter name shadows the builtin

    def bifurcate(code, nbif):
        # Build a node with two children. The right subtree is decoded
        # before the left one; the order matters because both compete for
        # the shared node budget.
        left_code, right_code, radius = decoder.bifurcationDecoder(code)
        node = createNode(1, radius)
        node.right = decode_node(right_code, nbif + 1)
        node.left = decode_node(left_code, nbif + 1)
        return node

    def decode_node(code, nbif):
        # nbif counts the bifurcations on the path from the root to here.
        _, label = torch.max(decoder.nodeClassifier(code), 1)
        label = int(label)
        if label == 1 and createNode.count <= node_budget:
            right_code, radius = decoder.internalDecoder(code)
            node = createNode(1, radius)
            node.right = decode_node(right_code, nbif)
            return node
        if label == 2 and createNode.count <= node_budget:
            return bifurcate(code, nbif)
        if label == 0:
            if nbif >= 2:
                return createNode(1, decoder.featureDecoder(code))
            # Too few bifurcations so far: force a split instead of a leaf.
            return bifurcate(code, nbif)
        return None

    createNode.count = 0
    # Pure inference: no autograd graph is needed for the decoded tensors.
    with torch.no_grad():
        return decode_node(v, 0)


def numerar_nodos(root, count):
    """Number the nodes of *root* in in-order (left, node, right).

    Each node's ``data`` is set to ``len(count)`` at the time it is visited,
    and ``count`` grows by one element per node, so numbering starts at the
    initial length of ``count``.
    """
    if root is not None:
        numerar_nodos(root.left, count)
        root.data = len(count)
        count.append(1)
        numerar_nodos(root.right, count)
    return


def traverse_xy(root, tree):
    """Append ``[radius[0][0], radius[0][1]]`` of every node (in-order) to *tree*.

    Also prints each node's radius array (debug output). Returns *tree*.
    """
    if root is not None:
        traverse_xy(root.left, tree)
        radius = root.radius.cpu().detach().numpy()
        print(radius)
        tree.append([radius[0][0], radius[0][1]])
        traverse_xy(root.right, tree)
    return tree


def tr(root):
    """In place, divide ``radius[0][3]`` of every node by 20 and clamp negatives to 0.

    Indexes ``radius[0][3]``, i.e. expects 2-D radius tensors such as the
    decoder outputs.
    """
    if root is not None:
        tr(root.left)
        root.radius[0][3] = root.radius[0][3] / 20
        if root.radius[0][3] < 0:
            root.radius[0][3] = 0
        tr(root.right)
    return


def tr2(root):
    """In place, divide ``radius[3]`` of every node by 10 and clamp negatives to 0.

    Indexes ``radius[3]``, i.e. expects 1-D radius tensors such as those
    built by ``deserialize``.
    """
    if root is not None:
        tr2(root.left)
        root.radius[3] = root.radius[3] / 10
        if root.radius[3] < 0:
            root.radius[3] = 0
        tr2(root.right)
    return


use_gpu = True
device = torch.device("cuda:0" if use_gpu and torch.cuda.is_available() else "cpu")


def traversefeatures(root, features):
    """Append ``radius[0][3]`` of every node (in-order) to *features*; return it."""
    if root is not None:
        traversefeatures(root.left, features)
        features.append(root.radius.tolist()[0][3])
        traversefeatures(root.right, features)
    return features


def my_collate(batch):
    """Return the batch list unchanged, keeping ``(tree, name)`` pairs as-is."""
    return batch


def deserialize(data):
    """Rebuild a tree from its ';'-separated serialization.

    Tokens are consumed from the end of the list. Each token is either '#'
    (an empty child) or ``<int data>_<[r0, r1, ...]>``. After a node is read,
    its right subtree and then its left subtree are read from the remaining
    tokens. Radius tensors are created on ``device``.

    Returns ``None`` for empty input.
    """
    if not data:
        return None
    tokens = data.split(';')

    def post_order(tokens):
        if tokens[-1] == '#':
            tokens.pop()
            return None
        fields = tokens.pop().split('_')
        node_data = int(fields[0])
        # Drop the list brackets, then parse every comma-separated component.
        radius_text = fields[1].replace('[', '').replace(']', '')
        radius = torch.tensor([float(value) for value in radius_text.split(',')],
                              device=device)
        root = createNode(node_data, radius)
        root.right = post_order(tokens)
        root.left = post_order(tokens)
        return root

    return post_order(tokens)


def read_tree(filename, dir):
    """Return the text content of ``./<dir>/<filename>``."""
    with open('./' + dir + '/' + filename, "r") as f:
        return f.read()


class tDataset(Dataset):
    """Map-style dataset of deserialized trees.

    Every file listed in *l* is read from *dir* and deserialized eagerly in
    ``__init__``. Items are ``(tree_root, filename)`` pairs. *transform* is
    stored but not applied.
    """

    def __init__(self, l, dir, transform=None):
        self.names = l
        self.transform = transform
        # Raw serialized strings, one per file.
        self.data = [read_tree(name, dir) for name in self.names]
        self.trees = [deserialize(text) for text in self.data]

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        return self.trees[idx], self.names[idx]


batch_size = 1


def prepararAristas(grafo):
    """Mark every edge of *grafo* as unprocessed (``procesada = False``), in place."""
    nx.set_edge_attributes(grafo, False, 'procesada')


# Dataset of stored trees. It is not used by ``predict``; it is kept as a
# module-level name for interactive/debug use.
t_list = os.listdir("./nuevostrees/")
dataset = tDataset(t_list, "./nuevostrees/")
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=my_collate)


def _sample_tree():
    """Decode one random latent vector into a graph.

    Uses the module-level ``latent_size``, ``Grassdecoder`` and ``mult``.

    Returns:
        ``(graph, min_radius, max_radius)``, where the radii are the extremes
        of ``radius[0][3]`` over all decoded nodes.
    """
    graph = nx.Graph()
    z = torch.randn(1, latent_size)
    root = decode_testing(z, 30, Grassdecoder, mult, 1)
    numerar_nodos(root, [])
    root.toGraph(graph, 0, True, flag=0)
    radii = traversefeatures(root, [])
    return graph, min(radii), max(radii)


def _build_mesh(graph_of_centers):
    """Tile *graph_of_centers* into a mesh, save it, subdivide it and validate it.

    Writes "sinsub.obj" (tiled mesh) and "output.obj" (after
    ``subdivide_loop(3)``) to the working directory.

    Returns:
        The subdivided open3d ``TriangleMesh``.

    Raises:
        Exception("autointerseccion"): pymeshlab selects self-intersecting
            faces on the subdivided mesh.
    """
    graph_of_centers.tile()
    mesh = tm.Trimesh(graph_of_centers.getVertices(), graph_of_centers.getCaras())
    coarse = o3d.geometry.TriangleMesh()
    coarse.vertices = o3d.utility.Vector3dVector(mesh.vertices)
    coarse.triangles = o3d.utility.Vector3iVector(mesh.faces)
    o3d.io.write_triangle_mesh("sinsub.obj", coarse)
    fine = coarse.subdivide_loop(3)
    o3d.io.write_triangle_mesh("output.obj", fine)
    ms = pm.MeshSet()
    ms.load_new_mesh('output.obj')
    ms.compute_selection_by_self_intersections_per_face()
    if ms.current_mesh().selected_face_number() > 0:
        raise Exception("autointerseccion")
    return fine


def _mesh_figure(mesh_o3d):
    """Plot *mesh_o3d* as a red Plotly ``Mesh3d`` with equal, hidden axes.

    Vertices are centred on their mean and divided by the largest
    axis-aligned extent.
    """
    vertices = np.asarray(mesh_o3d.vertices)
    faces = np.asarray(mesh_o3d.triangles)
    extent = (vertices.max(axis=0) - vertices.min(axis=0)).max()
    if extent == 0:  # degenerate mesh: avoid dividing by zero
        extent = 1.0
    normalized = (vertices - vertices.mean(axis=0)) / extent
    x, y, z = normalized.T
    lowest = normalized.min()
    highest = normalized.max()

    fig = go.Figure(go.Mesh3d(
        x=x, y=y, z=z,
        i=faces[:, 0], j=faces[:, 1], k=faces[:, 2],
        color='red'))
    fig.update_layout(
        scene=dict(
            xaxis_range=[lowest, highest],
            yaxis_range=[lowest, highest],
            zaxis_range=[lowest, highest],
            aspectratio=dict(x=1, y=1, z=1),
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
        ))
    return fig


def predict():
    """Generate a random tree mesh and return it as a Plotly figure.

    The function samples repeatedly until a tree passes both checks:
    - the ratio max/min of the decoded ``radius[0][3]`` values has absolute
      value below 10;
    - the tree can be meshed without raising.

    NOTE(review): there is no attempt limit, so the function loops forever if
    no sample ever succeeds.
    """
    while True:
        graph, min_radius, max_radius = _sample_tree()
        prepararAristas(graph)
        graph_of_centers = GrafoCentros(graph)

        # A zero minimum radius would raise ZeroDivisionError here, outside
        # the retry's try block; treat it as an out-of-range ratio instead.
        ratio = max_radius / min_radius if min_radius != 0 else float("inf")
        if not abs(ratio) < 10:  # written this way so NaN is also rejected
            print("ratio", ratio)
            continue

        try:
            mesh_o3d = _build_mesh(graph_of_centers)
        except Exception as e:
            # Best-effort: any meshing failure just triggers another sample.
            print("No se pudo generar")
            print(str(e))
            continue
        return _mesh_figure(mesh_o3d)


# Values for the decoder's ``mult`` tensor.
a = [1., 1., 1.]
# Build the GRASS decoder used by predict() and load its trained weights
# onto the CPU.
mult = torch.Tensor(a)
latent_size = 32
Grassdecoder = GRASSDecoder(latent_size=latent_size, hidden_size=256, mult=mult)
Grassdecoder.eval()
# NOTE: torch.load unpickles the file; only load trusted checkpoints.
checkpoint = torch.load("64bueno.pth", map_location=torch.device('cpu'))
Grassdecoder.load_state_dict(checkpoint['decoder_state_dict'])

# Gradio UI: no inputs; each run shows the figure returned by predict().
demo = gr.Interface(
    predict,
    inputs=None,
    outputs=gr.Plot(),
)

# Launch only when run as a script, so importing this module does not start
# a blocking server.
if __name__ == "__main__":
    demo.launch()