Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
from modelovae import Node, GRASSEncoder, GRASSDecoder | |
import networkx as nx | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from src.mesh_gen.modelador import GrafoCentros | |
import trimesh as tm | |
#import pymeshlab | |
import plotly.graph_objects as go | |
from torch.utils.data import Dataset, DataLoader | |
import os | |
from resamplear import * | |
import open3d as o3d | |
import pymeshlab as pm | |
from vedo import * | |
def count_fn(f):
    """Wrap `f` so that every call increments an integer ``count`` attribute.

    The returned wrapper starts with ``wrapper.count == 0``, adds 1 before each
    call to `f`, and forwards all arguments and the return value unchanged.
    ``functools.wraps`` keeps `f`'s name and docstring on the wrapper.
    """
    import functools

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        wrapper.count += 1
        return f(*args, **kwargs)

    # Set after wraps() so a copied ``count`` attribute from f cannot leak in.
    wrapper.count = 0
    return wrapper
@count_fn
def createNode(data, radius, left=None, right=None):
    """Utility function to create a node; calls are counted.

    The ``count_fn`` wrapper increments ``createNode.count`` on every call.
    decode_testing resets that counter to 0 and compares it against its node
    budget. Without this wrapper the counter never moves and the budget
    check never stops tree growth.
    """
    return Node(data, radius, left, right)
def decode_testing(v, max, decoder, mult, min):
    """Decode latent code `v` into a binary tree of nodes.

    For each latent code, ``decoder.nodeClassifier`` is argmax-ed over dim 1
    to choose the node type:

    * 1 -> ``decoder.internalDecoder`` returns (right_code, radius): a node
      with only a right child;
    * 2 -> ``decoder.bifurcationDecoder`` returns (left_code, right_code,
      radius): a node with two children;
    * 0 -> a leaf whose radius comes from ``decoder.featureDecoder``, unless
      fewer than two bifurcations lie on the path from the root, in which
      case the node is decoded as a bifurcation instead.

    Args:
        v: latent code of the root. Assumed to be a batch of one, since the
            label is read with ``.item()``.
        max: node budget. Types 1 and 2 are only expanded while
            ``createNode.count <= max``; beyond the budget such a node is
            dropped (None). The counter is reset to 0 here and is only
            advanced if ``createNode`` is wrapped with ``count_fn``.
        decoder: model exposing the four heads used above.
        mult, min: unused; kept for call compatibility. Note that ``max`` and
            ``min`` shadow the builtins inside this function.

    Returns:
        The root node, or None if the root itself was dropped by the budget.
    """

    def decode_bifurcation(code, nbif):
        # Shared by type-2 nodes and forced splits of too-shallow leaves.
        left, right, radius = decoder.bifurcationDecoder(code)
        node = createNode(1, radius)
        # The right subtree is decoded first, so it consumes budget first.
        node.right = decode_node(right, nbif + 1)
        node.left = decode_node(left, nbif + 1)
        return node

    def decode_node(code, nbif):
        # nbif = number of bifurcations on the path from the root to here.
        _, label = torch.max(decoder.nodeClassifier(code), 1)
        label = label.item()
        within_budget = createNode.count <= max
        if label == 1 and within_budget:
            right, radius = decoder.internalDecoder(code)
            node = createNode(1, radius)
            node.right = decode_node(right, nbif)
            return node
        if label == 2 and within_budget:
            return decode_bifurcation(code, nbif)
        if label == 0:
            if nbif >= 2:
                return createNode(1, decoder.featureDecoder(code))
            # Fewer than two bifurcations above this node: force another split.
            return decode_bifurcation(code, nbif)
        # Type 1/2 beyond the node budget (or an unexpected label): no node.
        return None

    createNode.count = 0
    return decode_node(v, nbif=0)
def numerar_nodos(root, count):
    """Number the nodes of a binary tree in in-order sequence.

    `count` doubles as the running counter. Each visited node gets
    ``data = len(count)`` and a ``1`` is then appended to `count`, so numbering
    continues from whatever length `count` already had. Uses an explicit stack
    instead of recursion. Returns None.
    """
    pending = []
    node = root
    while node is not None or pending:
        if node is not None:
            pending.append(node)
            node = node.left
        else:
            node = pending.pop()
            node.data = len(count)
            count.append(1)
            node = node.right
def traverse_xy(root, tree):
    """Collect, in-order, the first two entries of each node's radius row.

    For every node, the radius tensor is moved to CPU as a numpy array and
    printed (debug output). Then ``[arr[0][0], arr[0][1]]`` is appended to
    `tree`. Returns `tree`, which is mutated in place.
    """
    if root is None:
        return tree
    traverse_xy(root.left, tree)
    as_array = root.radius.cpu().detach().numpy()
    print(as_array)
    tree.append([as_array[0][0], as_array[0][1]])
    traverse_xy(root.right, tree)
    return tree
def tr(root):
    """Divide each node's ``radius[0][3]`` by 20 in place, clamping negatives to 0.

    Nodes are visited in-order (left subtree, node, right subtree) with an
    explicit stack. Returns None.
    """
    stack = []
    node = root
    while node is not None or stack:
        if node is not None:
            stack.append(node)
            node = node.left
        else:
            node = stack.pop()
            row = node.radius[0]
            row[3] = row[3] / 20
            if row[3] < 0:
                row[3] = 0
            node = node.right
def tr2(root):
    """Divide each node's ``radius[3]`` by 10 in place, clamping negatives to 0.

    Like ``tr`` but for a flat (1-D indexed) radius. Nodes are visited in-order
    through a recursive generator. Returns None.
    """
    def walk(node):
        if node is None:
            return
        yield from walk(node.left)
        yield node
        yield from walk(node.right)

    for current in walk(root):
        feats = current.radius
        feats[3] = feats[3] / 10
        if feats[3] < 0:
            feats[3] = 0
# Select the first CUDA device when GPU use is enabled and CUDA is present,
# otherwise the CPU. deserialize() creates its radius tensors on this device.
use_gpu = True
device = (
    torch.device("cuda:0")
    if use_gpu and torch.cuda.is_available()
    else torch.device("cpu")
)
def traversefeatures(root, features):
    """Append ``radius.tolist()[0][3]`` of every node, in-order, to `features`.

    Iterative in-order walk with an explicit stack. Returns `features`, which
    is mutated in place.
    """
    stack = []
    current = root
    while current is not None or stack:
        while current is not None:
            stack.append(current)
            current = current.left
        current = stack.pop()
        features.append(current.radius.tolist()[0][3])
        current = current.right
    return features
def my_collate(batch):
    """Identity collate function for the tree DataLoader.

    Returns the list of dataset items, which are (tree, name) pairs, exactly as
    received. Presumably this is because tree objects cannot be stacked by
    the default collate.
    """
    collated = batch
    return collated
def deserialize(data):
    """Rebuild a tree from its ';'-separated string form.

    Tokens are consumed from the END of the token list. The last token is the
    root, the tokens before it encode its right subtree, and the tokens before
    those encode its left subtree, so the string is a post-order
    serialization (left, right, node). ``'#'`` marks a missing child. Every
    other token has the form ``<int>_[v0,v1,...]``: the integer becomes the
    node's ``data`` and the bracketed comma-separated floats become its
    ``radius`` tensor on ``device``.

    Returns:
        The root node, or None when `data` is empty or None.

    Raises:
        IndexError if the tokens run out before the tree is complete;
        ValueError for a malformed integer or float value.
    """
    if not data:
        return None
    tokens = data.split(';')

    def parse_radius(text):
        # Drop every bracket, then parse each comma-separated value. Unlike the
        # previous fixed 4-field parse, this accepts vectors of any length.
        values = text.replace('[', '').replace(']', '').split(',')
        return torch.tensor([float(value) for value in values], device=device)

    def build():
        token = tokens.pop()
        if token == '#':
            return None
        fields = token.split('_')
        root = createNode(int(fields[0]), parse_radius(fields[1]))
        # The right subtree's tokens sit immediately before this node's token.
        root.right = build()
        root.left = build()
        return root

    return build()
def read_tree(filename, dir):
    """Return the full text of ``<dir>/<filename>``.

    A relative `dir` is resolved against the current directory, as before. An
    absolute `dir` is now honoured: the previous ``'./' + dir`` concatenation
    turned ``/abs/path`` into the relative ``.//abs/path``.
    """
    path = os.path.join('.', dir, filename)
    with open(path, "r") as f:
        return f.read()
class tDataset(Dataset):
    """Map-style dataset of serialized trees stored as files in one directory.

    Every file is read and deserialized eagerly in ``__init__``. Item `idx` is
    the pair ``(tree, filename)``.
    """

    def __init__(self, l, dir, transform=None):
        self.names = l
        # Kept on the instance, but not applied anywhere in this class.
        self.transform = transform
        # Raw serialized strings, one per file, in the order of `l`.
        self.data = [read_tree(name, dir) for name in self.names]
        # Deserialized roots, parallel to self.names / self.data.
        self.trees = [deserialize(text) for text in self.data]

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        return self.trees[idx], self.names[idx]
# Number of trees per DataLoader batch.
batch_size: int = 1
def prepararAristas(grafo):
    """Mark every edge of `grafo` as unprocessed (``'procesada': False``).

    Updates the graph's edge attributes in place and returns None.
    """
    unprocessed = {edge: {'procesada': False} for edge in grafo.edges()}
    nx.set_edge_attributes(grafo, unprocessed)
# Load every file in the trees directory as a dataset at import time and
# wrap it in a shuffling DataLoader that yields raw (tree, name) lists.
_TREES_DIR = "./nuevostrees/"
t_list = os.listdir(_TREES_DIR)
dataset = tDataset(t_list, _TREES_DIR)
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=my_collate,
)
def _sample_tree_graph():
    """Decode one random latent code into a tree and convert it to a graph.

    Returns:
        (graph, max_radius, min_radius): the networkx graph filled by the
        tree's ``toGraph``, and the largest and smallest ``radius[0][3]``
        values over all nodes, as collected by ``traversefeatures``.
    """
    graph = nx.Graph()
    latent = torch.randn(1, latent_size)
    tree = decode_testing(latent, 30, Grassdecoder, mult, 1)
    # Assign in-order ids to node.data (presumably the ids toGraph uses).
    numerar_nodos(tree, [])
    tree.toGraph(graph, 0, True, flag=0)
    radii = traversefeatures(tree, [])
    return graph, max(radii), min(radii)


def _mesh_from_graph(graph):
    """Tile a mesh around `graph` and apply 3 Loop-subdivision iterations.

    Side effects: writes the coarse mesh to "sinsub.obj" and the subdivided
    mesh to "output.obj" in the working directory.

    Returns:
        The subdivided open3d TriangleMesh.

    Raises:
        Exception("autointerseccion") if pymeshlab selects any
        self-intersecting face. Errors from tiling or meshing propagate.
    """
    prepararAristas(graph)
    centers = GrafoCentros(graph)
    centers.tile()
    tri = tm.Trimesh(centers.getVertices(), centers.getCaras())
    coarse = o3d.geometry.TriangleMesh()
    coarse.vertices = o3d.utility.Vector3dVector(tri.vertices)
    coarse.triangles = o3d.utility.Vector3iVector(tri.faces)
    o3d.io.write_triangle_mesh("sinsub.obj", coarse)
    fine = coarse.subdivide_loop(3)
    o3d.io.write_triangle_mesh("output.obj", fine)
    # Reload the written file in pymeshlab to run its self-intersection check.
    ms = pm.MeshSet()
    ms.load_new_mesh('output.obj')
    ms.compute_selection_by_self_intersections_per_face()
    if ms.current_mesh().selected_face_number() > 0:
        raise Exception("autointerseccion")
    return fine


def _centered_mesh_figure(vertices, faces):
    """Build a Plotly Mesh3d of a mesh centred and scaled for display.

    Vertices are shifted by their per-axis mean and divided by the largest
    per-axis extent. All three axes share one range (the global min/max of the
    normalized coordinates) with a 1:1:1 aspect ratio, so the shape is not
    distorted. Axes are hidden.
    """
    extent = (vertices.max(axis=0) - vertices.min(axis=0)).max()
    coords = (vertices - vertices.mean(axis=0)) / extent
    lo, hi = coords.min(), coords.max()
    fig = go.Figure(go.Mesh3d(
        x=coords[:, 0], y=coords[:, 1], z=coords[:, 2],
        i=faces[:, 0], j=faces[:, 1], k=faces[:, 2],
        color='red'))
    fig.update_layout(
        scene=dict(
            xaxis_range=[lo, hi],
            yaxis_range=[lo, hi],
            zaxis_range=[lo, hi],
            aspectratio=dict(x=1, y=1, z=1),
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
        ))
    return fig


def predict():
    """Gradio callback: generate a random tree mesh and return its plot.

    Repeatedly samples a tree (see ``_sample_tree_graph``) until both of these
    hold, then returns a Plotly figure of the subdivided mesh:

    * the radius ratio ``|max/min|`` is below 10;
    * meshing succeeds without self-intersections.

    There is no attempt limit; this loops until a sample succeeds.
    """
    while True:
        graph, max_radius, min_radius = _sample_tree_graph()
        # A zero minimum radius would otherwise raise ZeroDivisionError outside
        # the retry logic and abort the request; treat it as an unbounded ratio.
        ratio = max_radius / min_radius if min_radius != 0 else float("inf")
        if not abs(ratio) < 10:  # written this way so NaN is rejected too
            print("ratio", ratio)
            continue
        try:
            mesh = _mesh_from_graph(graph)
        except Exception as e:  # deliberate best-effort: any failure -> resample
            print("No se pudo generar")
            print(str(e))
            continue
        return _centered_mesh_figure(np.asarray(mesh.vertices),
                                     np.asarray(mesh.triangles))
# Decoder hyper-parameters. load_state_dict below rejects weights whose
# shapes do not match them.
a = [1., 1., 1.]
mult = torch.Tensor(a)
latent_size = 32
Grassdecoder = GRASSDecoder(latent_size=latent_size, hidden_size=256, mult=mult)
Grassdecoder.eval()
# Weights are mapped to the CPU; predict() also samples its latent code on CPU.
checkpoint = torch.load("64bueno.pth", map_location=torch.device('cpu'))
Grassdecoder.load_state_dict(checkpoint['decoder_state_dict'])

# Module-level `demo` so tooling (e.g. gradio reload mode) can find the app;
# the server only starts when this file is run as a script.
demo = gr.Interface(
    predict,
    inputs=None,
    outputs=gr.Plot(),
)

if __name__ == "__main__":
    demo.launch()