# vv / app.py
# Commit a0f9147 by paufeldman: "anda con html en hf"
# (Spanish: "works with HTML on HF").
import gradio as gr
import torch
from modelovae import Node, GRASSEncoder, GRASSDecoder
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from src.mesh_gen.modelador import GrafoCentros
import trimesh as tm
#import pymeshlab
import plotly.graph_objects as go
from torch.utils.data import Dataset, DataLoader
import os
from resamplear import *
import open3d as o3d
import pymeshlab as pm
#from vedo import *
def count_fn(f):
    """Decorator that counts how many times *f* is called.

    The running total is exposed as ``wrapper.count`` (starting at 0) and may
    be reset by callers, e.g. ``createNode.count = 0``. ``functools.wraps``
    keeps the wrapped function's name and docstring visible.
    """
    from functools import wraps

    @wraps(f)
    def wrapper(*args, **kwargs):
        # Incremented before the call, so calls that raise are counted too.
        wrapper.count += 1
        return f(*args, **kwargs)

    wrapper.count = 0
    return wrapper
@count_fn
def createNode(data, radius, left=None, right=None):
    """Build a ``Node`` while bumping the ``createNode.count`` call counter.

    Thin factory around ``Node(data, radius, left, right)``. All node
    construction is routed through it, so ``decode_testing`` can bound tree
    size by reading ``createNode.count``.
    """
    node = Node(data, radius, left, right)
    return node
def decode_testing(v, max, decoder, mult, min):
    """Decode the latent vector *v* into a binary tree of ``Node`` objects.

    ``createNode.count`` is reset to 0 and then used as a node budget:
    classifier label 1 (``decoder.internalDecoder``, one child) and label 2
    (``decoder.bifurcationDecoder``, two children) are only expanded while
    ``createNode.count <= max``; past the budget such a node yields ``None``.
    A label-0 node becomes a childless node (radius from
    ``decoder.featureDecoder``) only once its path from the root contains at
    least two bifurcations; before that it is forced to bifurcate.

    NOTE(review): ``mult`` and ``min`` are threaded through the recursion but
    never read; the leaf test uses a hard-coded ``nbif >= 2``. Confirm
    whether ``min`` was meant to be used there.

    Returns the decoded root node.
    """

    def bifurcate(code, max, decoder, mult, min, nbif):
        # Two-child expansion, shared by label 2 and by an early label 0.
        left_code, right_code, radius = decoder.bifurcationDecoder(code)
        node = createNode(1, radius)
        # Right before left: this decides which subtree consumes the node
        # budget first, so the order must not be swapped.
        node.right = decode_node(right_code, max, decoder, mult, min, nbif=nbif + 1)
        node.left = decode_node(left_code, max, decoder, mult, min, nbif=nbif + 1)
        return node

    def decode_node(v, max, decoder, mult, min, nbif):
        # nbif = number of bifurcations on the path from the root to here.
        _, label = torch.max(decoder.nodeClassifier(v), 1)
        label = label.data
        has_budget = createNode.count <= max
        if label == 1 and has_budget:
            child_code, radius = decoder.internalDecoder(v)
            node = createNode(1, radius)
            node.right = decode_node(child_code, max, decoder, mult, min, nbif=nbif)
            return node
        if label == 2 and has_budget:
            return bifurcate(v, max, decoder, mult, min, nbif)
        if label == 0:
            if nbif >= 2:
                return createNode(1, decoder.featureDecoder(v))
            return bifurcate(v, max, decoder, mult, min, nbif)
        return None

    createNode.count = 0
    return decode_node(v, max, decoder, mult, min, nbif=0)
def numerar_nodos(root, count):
    """Number the nodes of a binary tree in in-order sequence.

    Each visited node's ``data`` is set to ``len(count)`` and a ``1`` is then
    appended to *count*, so ids continue from the list's current length.
    Mutates the nodes and *count* in place; returns ``None``.
    """
    pending = []
    node = root
    while pending or node is not None:
        # Descend as far left as possible, remembering the path.
        while node is not None:
            pending.append(node)
            node = node.left
        node = pending.pop()
        node.data = len(count)
        count.append(1)
        node = node.right
def traverse_xy(root, tree):
    """Collect ``[radius[0][0], radius[0][1]]`` of every node, in-order.

    Prints each node's radius (as a numpy array) while walking, appends the
    pair to *tree* and returns *tree*.
    """
    if root is None:
        return tree
    traverse_xy(root.left, tree)
    values = root.radius.cpu().detach().numpy()
    print(values)
    tree.append([values[0][0], values[0][1]])
    traverse_xy(root.right, tree)
    return tree
def tr(root):
    """Divide ``radius[0][3]`` of every node by 20, clamping negatives to 0.

    Walks the tree in-order and mutates each ``radius`` in place (it must be
    indexable as ``radius[0][3]``). Returns ``None``.
    """
    pending = []
    node = root
    while pending or node is not None:
        if node is not None:
            pending.append(node)
            node = node.left
            continue
        node = pending.pop()
        node.radius[0][3] = node.radius[0][3] / 20
        if node.radius[0][3] < 0:
            node.radius[0][3] = 0
        node = node.right
def tr2(root):
    """Divide ``radius[3]`` of every node by 10, clamping negatives to 0.

    Same walk as ``tr`` but for 1-D radii (indexed as ``radius[3]``).
    Mutates the nodes in place (in-order) and returns ``None``.
    """
    if root is None:
        return
    tr2(root.left)
    scaled = root.radius[3] / 10
    root.radius[3] = scaled
    if root.radius[3] < 0:
        root.radius[3] = 0
    tr2(root.right)
# Prefer the first CUDA device when GPU use is enabled and one is present.
use_gpu = True
if use_gpu and torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
def traversefeatures(root, features):
    """Append ``radius[0][3]`` (as a Python number) of every node, in-order.

    *features* is extended in place and also returned.
    """
    stack = []
    node = root
    while stack or node is not None:
        while node is not None:
            stack.append(node)
            node = node.left
        node = stack.pop()
        features.append(node.radius.tolist()[0][3])
        node = node.right
    return features
def my_collate(batch):
    """Identity collate function for ``DataLoader``.

    Hands the list of samples (``(tree, name)`` pairs from ``tDataset``)
    through unchanged instead of stacking them into tensors.
    """
    samples = batch
    return samples
def deserialize(data):
    """Rebuild a ``Node`` tree from its ``;``-separated post-order serialization.

    Each token is either ``#`` (an absent child) or ``<id>_[r0,r1,...]``: an
    integer node id followed by a bracketed, comma-separated list of floats
    that becomes the node's ``radius`` tensor (created on the module-level
    ``device``). Tokens are consumed from the end of the string, so the last
    token is the root, followed by its right subtree and then its left one.

    Returns ``None`` for empty input.
    """
    if not data:
        return None
    tokens = data.split(';')

    def post_order(stack):
        # Tolerate surrounding whitespace (e.g. a trailing newline) on markers.
        if stack[-1].strip() == '#':
            stack.pop()
            return None
        fields = stack.pop().split('_')
        node_id = int(fields[0])
        # Strip the brackets from the whole list instead of patching
        # elements 0 and 3, so radius vectors of any length parse.
        values = fields[1].strip().strip('[]').split(',')
        radius = torch.tensor([float(v) for v in values], device=device)
        root = createNode(node_id, radius)
        root.right = post_order(stack)
        root.left = post_order(stack)
        return root

    return post_order(tokens)
def read_tree(filename, dir):
    """Return the text of ``./<dir>/<filename>`` (relative to the working directory).

    The file is decoded as UTF-8 explicitly, so parsing does not depend on
    the platform's locale encoding.
    """
    path = './' + dir + '/' + filename
    with open(path, "r", encoding="utf-8") as f:
        return f.read()
class tDataset(Dataset):
    """Dataset of serialized trees, all loaded eagerly at construction time.

    Every file name in *l* is read from *dir* with ``read_tree`` and parsed
    with ``deserialize``. Items are ``(tree, filename)`` pairs.
    ``transform`` is stored but not applied anywhere in this class.
    """

    def __init__(self, l, dir, transform=None):
        self.names = l
        self.transform = transform
        # Raw serialized strings, parallel to ``names``.
        self.data = [read_tree(name, dir) for name in self.names]
        # Decoded trees, parallel to ``names`` / ``data``.
        self.trees = [deserialize(text) for text in self.data]

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        return self.trees[idx], self.names[idx]
# Samples per batch for the module-level DataLoader (batches are plain lists,
# see my_collate).
batch_size = 1
def prepararAristas(grafo):
    """Mark every edge of *grafo* as not yet processed (``procesada = False``).

    Mutates the graph's edge data in place and returns ``None``.
    """
    # A scalar value plus an attribute name is applied to every edge in one
    # pass, instead of one set_edge_attributes call per edge. Unlike the
    # per-edge {(u, v): {...}} form, this also works for MultiGraphs.
    nx.set_edge_attributes(grafo, False, 'procesada')
# Eagerly read and parse every serialized tree under ./nuevostrees/ at
# import time, and wrap them in a shuffling DataLoader.
# NOTE(review): no executable code in this file iterates data_loader; confirm
# it is still needed before removing.
t_list = os.listdir("./nuevostrees/")
dataset = tDataset(t_list, "./nuevostrees/")
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=my_collate,
)
import plotly.io as pio
def _sample_tree_graph():
    """Decode one random latent vector into a tree and convert it to a graph.

    Returns ``(graph, radii)``: an ``nx.Graph`` filled by ``Node.toGraph``
    and the in-order list of ``radius[0][3]`` values from
    ``traversefeatures``.
    """
    graph = nx.Graph()
    z = torch.randn(1, latent_size)
    # 30 is the node budget (``max``); 1 is passed as ``min``.
    tree = decode_testing(z, 30, Grassdecoder, mult, 1)
    numerar_nodos(tree, [])
    tree.toGraph(graph, 0, True, flag=0)
    radii = traversefeatures(tree, [])
    return graph, radii


def _radius_ratio(radii):
    """Return ``max(radii) / min(radii)``, or ``inf`` when the minimum is zero.

    The zero guard replaces a ZeroDivisionError that used to abort the whole
    request; an infinite ratio just makes the caller draw a new sample.
    """
    smallest = min(radii)
    largest = max(radii)
    if smallest == 0:
        return float("inf")
    return largest / smallest


def _build_mesh(graph_of_centers):
    """Tile a ``GrafoCentros`` into a mesh, subdivide it and reject self-intersections.

    Writes ``sinsub.obj`` (before subdivision) and ``output.obj`` (after
    ``subdivide_loop(3)``) to the working directory as a side effect.

    Returns ``(vertices, triangles)`` numpy arrays of the subdivided mesh.

    Raises:
        Exception: ``"autointerseccion"`` when pymeshlab selects any
            self-intersecting face, or whatever the meshing libraries raise.
    """
    graph_of_centers.tile()
    mesh = tm.Trimesh(graph_of_centers.getVertices(), graph_of_centers.getCaras())
    coarse = o3d.geometry.TriangleMesh()
    coarse.vertices = o3d.utility.Vector3dVector(mesh.vertices)
    coarse.triangles = o3d.utility.Vector3iVector(mesh.faces)
    o3d.io.write_triangle_mesh("sinsub.obj", coarse)
    fine = coarse.subdivide_loop(3)
    o3d.io.write_triangle_mesh("output.obj", fine)
    # pymeshlab re-reads the subdivided mesh from disk to flag bad faces.
    ms = pm.MeshSet()
    ms.load_new_mesh('output.obj')
    ms.compute_selection_by_self_intersections_per_face()
    if ms.current_mesh().selected_face_number() > 0:
        raise Exception("autointerseccion")
    return np.asarray(fine.vertices), np.asarray(fine.triangles)


def _mesh_to_iframe(vertices, faces):
    """Render a triangle mesh with plotly and embed it in an ``<iframe srcdoc>``.

    Vertices are centred on their per-axis mean and divided by the largest
    bounding-box extent; all three axes share one range so the aspect ratio
    is preserved.
    """
    from html import escape

    extent = max(vertices[:, k].max() - vertices[:, k].min() for k in range(3))
    x, y, z = ((vertices[:, k] - np.mean(vertices[:, k])) / extent for k in range(3))
    lo = min(x.min(), y.min(), z.min())
    hi = max(x.max(), y.max(), z.max())
    fig = go.Figure(go.Mesh3d(
        x=x, y=y, z=z,
        i=faces[:, 0], j=faces[:, 1], k=faces[:, 2],
        color='red'))
    fig.update_layout(
        scene=dict(
            xaxis_range=[lo, hi],
            yaxis_range=[lo, hi],
            zaxis_range=[lo, hi],
            aspectratio=dict(x=1, y=1, z=1),
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
        ))
    page = pio.to_html(fig, full_html=True, include_plotlyjs='cdn')
    # Fully escape the document (``&`` first, then quotes and brackets) so the
    # srcdoc attribute decodes back to exactly ``page``. Escaping only ``"``
    # would let entities such as ``&amp;`` inside the page be decoded.
    return f'<iframe width="100%" height="600" frameborder="0" srcdoc="{escape(page, quote=True)}"></iframe>'


def predict():
    """Generate one random tree mesh and return it as embeddable HTML.

    Keeps sampling latent vectors until a tree has ``|max/min radius| < 10``
    and can be meshed without self-intersections. It then renders the
    subdivided mesh as an ``<iframe>`` string. Takes no inputs. Writes
    ``sinsub.obj`` and ``output.obj`` in the working directory.

    NOTE(review): there is no retry cap, so a decoder that never yields an
    acceptable mesh makes this loop forever.
    """
    while True:
        graph, radii = _sample_tree_graph()
        ratio = _radius_ratio(radii)
        # Written as ``not (... < 10)`` so NaN ratios are rejected too.
        if not (abs(ratio) < 10):
            print("ratio", ratio)
            continue
        # Resampling (resamplear) is currently disabled; the raw graph is used.
        prepararAristas(graph)
        graph_of_centers = GrafoCentros(graph)
        try:
            vertices, faces = _build_mesh(graph_of_centers)
        except Exception as e:
            # Best effort: report the failure and draw a new sample.
            print("No se pudo generar")
            print(str(e))
            continue
        return _mesh_to_iframe(vertices, faces)
# Build the GRASS decoder on the CPU and load its pretrained weights.
a = [1., 1., 1.]
mult = torch.tensor(a)
latent_size = 32
Grassdecoder = GRASSDecoder(latent_size=latent_size, hidden_size=256, mult=mult)
Grassdecoder.eval()
# weights_only=True restricts unpickling to tensors and primitive containers.
checkpoint = torch.load("64bueno.pth", map_location=torch.device('cpu'), weights_only=True)
Grassdecoder.load_state_dict(checkpoint['decoder_state_dict'])
# Launch the Gradio UI. It takes no inputs; predict() returns an HTML
# <iframe> string, shown through gr.HTML ("HF-safe" per its label).
gr.Interface(
    predict,
    inputs=None,
    outputs=gr.HTML(label="Embedded Plot (HF-safe)"),
).launch()
# Earlier local-only variant, kept for reference. It would need predict to
# return a (figure, iframe_html) pair:
# gr.Interface(predict, inputs=None,
#              outputs=[gr.Plot(label="Plotly Plot (local only)"),
#                       gr.HTML(label="Embedded Plot (HF-safe)")]).launch()