# app.py — Hugging Face Space by LuyangZ
# Last commit: 2e8f2bb "Update app.py" (verified; malware scan: no virus), 6.57 kB.
import functools

import gradio
import torch
import torch.nn as nn
from torch_geometric.loader import DataLoader

import utils.clean_data as cd
import utils.edge_features as ef
import utils.node_features as nf
import utils.shape_features as sf
# from datetime import datetime
# start_time = datetime.now()
# Relative paths (resolved against the current working directory) to the
# pickled node-graph and edge-graph embedding models used by outline_embedding.
node_model_path, edge_model_path = (
    'utils/emb_model/Node_64.pt',
    'utils/emb_model/Edge_64.pt',
)
class InfoGraph(nn.Module):
    """Graph encoder plus local/global discriminator heads, trained with ``local_global_loss_``.

    In this app only ``model.encoder.get_embeddings`` is used on models loaded with
    ``torch.load``; the class is presumably defined here so that unpickling can resolve
    it — verify before removing.

    NOTE(review): ``Encoder``, ``FF``, ``PriorDiscriminator``, ``local_global_loss_`` and
    ``dataset_num_features`` are not defined in the visible part of this file; calling
    ``__init__`` or ``forward`` needs them to be in scope.
    """

    def __init__(self, hidden_dim, num_gc_layers, alpha=0.5, beta=1., gamma=.1,
                 num_features=None):
        """Build the encoder and discriminator heads.

        Args:
            hidden_dim: hidden size per GNN layer.
            num_gc_layers: number of GNN layers; ``embedding_dim = hidden_dim * num_gc_layers``.
            alpha, beta: stored on the instance (not used by the code in this class).
            gamma: weight of the prior term in ``forward``.
            num_features: input feature size for the encoder. Defaults to the module
                global ``dataset_num_features`` (the original behavior) when omitted.
        """
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.prior = False
        self.embedding_dim = hidden_dim * num_gc_layers
        if num_features is None:
            num_features = dataset_num_features
        self.encoder = Encoder(num_features, hidden_dim, num_gc_layers)
        self.local_d = FF(self.embedding_dim)
        self.global_d = FF(self.embedding_dim)
        # Dead while ``self.prior`` is hard-coded to False above.
        if self.prior:
            self.prior_d = PriorDiscriminator(self.embedding_dim)
        self.init_emb()

    def init_emb(self):
        """Xavier-uniform initialise every ``nn.Linear`` weight in the model and zero its bias."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight.data)
                if module.bias is not None:
                    module.bias.data.fill_(0.0)

    def forward(self, x, edge_index, batch, num_graphs):
        """Return ``local_global_loss_`` (measure ``'JSD'``) plus the prior term when enabled.

        Args:
            x: node features, or ``None`` to use a constant 1 for every node in ``batch``.
            edge_index, batch: graph connectivity and node-to-graph assignment tensors.
            num_graphs: accepted for interface compatibility; unused.
        """
        if x is None:
            # Create the placeholder on batch's device; the original referenced an
            # undefined global ``device`` here and raised NameError.
            x = torch.ones(batch.shape[0], device=batch.device)
        y, M = self.encoder(x, edge_index, batch)
        g_enc = self.global_d(y)
        l_enc = self.local_d(M)
        measure = 'JSD'
        local_global_loss = local_global_loss_(l_enc, g_enc, edge_index, batch, measure)
        if self.prior:
            prior = torch.rand_like(y)
            term_a = torch.log(self.prior_d(prior)).mean()
            term_b = torch.log(1.0 - self.prior_d(y)).mean()
            PRIOR = - (term_a + term_b) * self.gamma
        else:
            PRIOR = 0
        return local_global_loss + PRIOR
@functools.lru_cache(maxsize=None)
def _load_model(model_path):
    """Load the pickled model at ``model_path`` once; later calls return the cached object.

    Only the two module-level model paths are passed in, so the cache stays bounded.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only use it on trusted files.
    return torch.load(model_path)


def _graph_embedding(model_path, graph):
    """Run the cached model's ``encoder.get_embeddings`` over ``graph``; return element ``[0]`` via ``tolist()``."""
    model = _load_model(model_path)
    # The model object is shared between requests now, so re-assert eval mode each call.
    model.eval()
    loader = DataLoader(graph, batch_size=1)
    embeddings = model.encoder.get_embeddings(loader)
    return embeddings[0].tolist()


def _shape_features(apa_geo, apa_line, apa_wall):
    """Compute the scalar shape descriptors of an outline as a list of floats.

    The order of the returned list is part of the output contract; do not reorder it.
    """
    xarr4cv, yarr4cv = apa_geo.exterior.coords.xy
    x4cv = xarr4cv.tolist()
    y4cv = yarr4cv.tolist()
    # Offset/scale passed to cd.points4cv and the helpers that consume p_4_cv
    # (presumably an integer-scaled point array for OpenCV — TODO confirm).
    scale = 100000
    xmin_abs = abs(min(x4cv))
    ymin_abs = abs(min(y4cv))
    p_4_cv = cd.points4cv(x4cv, y4cv, xmin_abs, ymin_abs, scale)
    grid_points = cd.gridpoints(apa_geo, 1)

    (Dir_S_longestedge, Dir_N_longestedge, Dir_W_longestedge, Dir_E_longestedge,
     Dir_S_max, Dir_N_max, Dir_W_max, Dir_E_max,
     Facade_length, Facade_ratio) = sf.wall_direction_ratio(apa_line, apa_wall)
    Perimeter = sf.apartment_perimeter(apa_geo)
    Area = sf.apartment_area(apa_geo)
    (BBox_width_x, BBox_height_y, Aspect_ratio, Extent,
     ULC_x, ULC_y, LRC_x, LRC_y) = sf.boundingbox_features(apa_geo)
    Max_diameter = sf.max_diameter(apa_geo)
    Fractality = sf.fractality(apa_geo)
    Circularity = sf.circularity(apa_geo)
    Outer_radius = sf.outer_radius(p_4_cv, xmin_abs, ymin_abs, scale)
    Inner_radius = sf.inner_radius(apa_geo, apa_line)
    Dist_mean, Dist_sigma, Roundness = sf.roundness_features(apa_line)
    Compactness = sf.compactness(apa_geo)
    Equivalent_diameter = sf.equivalent_diameter(apa_geo)
    Shape_membership_index = sf.shape_membership_index(apa_line)
    Convexity, Hull_geo = sf.convexity(p_4_cv, apa_geo, xmin_abs, ymin_abs, scale)
    Rectangularity, Rect_phi, Rect_width, Rect_height = sf.rectangle_features(
        p_4_cv, apa_geo, xmin_abs, ymin_abs, scale)
    Squareness = sf.squareness(apa_geo)
    Moment_index = sf.moment_index(apa_geo, Convexity, Compactness)
    nDetour_index = sf.ndetour_index(apa_geo, Hull_geo)
    nCohesion_index = sf.ncohesion_index(apa_geo, grid_points)
    nProximity_index, nSpin_index = sf.nproximity_nspin_index(apa_geo, grid_points)
    nExchange_index = sf.nexchange_index(apa_geo)
    nPerimeter_index = sf.nperimeter_index(apa_geo)
    nDepth_index = sf.ndepth_index(apa_geo, apa_line, grid_points)
    nGirth_index = sf.ngirth_index(apa_geo, Inner_radius)
    nRange_index = sf.nrange_index(apa_geo, Outer_radius)
    nTraversal_index = sf.ntraversal_index(apa_geo, apa_line)

    shape = [
        Dir_S_longestedge, Dir_N_longestedge, Dir_W_longestedge, Dir_E_longestedge,
        Dir_S_max, Dir_N_max, Dir_W_max, Dir_E_max, Facade_length, Facade_ratio,
        Perimeter, Area,
        BBox_width_x, BBox_height_y, Aspect_ratio, Extent, ULC_x, ULC_y, LRC_x, LRC_y,
        Max_diameter, Fractality, Circularity, Outer_radius, Inner_radius,
        Dist_mean, Dist_sigma, Roundness,
        Compactness, Equivalent_diameter, Shape_membership_index, Convexity,
        Rectangularity, Rect_phi, Rect_width, Rect_height,
        Squareness, Moment_index, nDetour_index, nCohesion_index,
        nProximity_index, nExchange_index, nSpin_index, nPerimeter_index,
        nDepth_index, nGirth_index, nRange_index, nTraversal_index,
    ]
    return [float(value) for value in shape]


def outline_embedding(wkt, wall):
    """Embed an outline as shape descriptors plus node-graph and edge-graph embeddings.

    Args:
        wkt: text from the "wkt" textbox; parsed together with ``wall`` by ``cd.read_wall_wkt``.
        wall: text from the "wall" textbox.

    Returns:
        dict with keys ``"edge"`` (edge-model embedding list), ``"shape"`` (list of
        float shape descriptors) and ``"node"`` (node-model embedding list).
    """
    wall_f, wkt_f = cd.read_wall_wkt(wall, wkt)
    apa_wall, apa_geo = cd.clean_geometry(wall_f, wkt_f)
    apa_line = apa_geo.boundary
    # NOTE(review): the result was never used; the call is kept in case it validates
    # or mutates its inputs — remove once confirmed side-effect free.
    cd.exterior_wall(apa_line, apa_wall)
    apa_coor = cd.geo_coor(apa_geo)

    shape = _shape_features(apa_geo, apa_line, apa_wall)
    node = _graph_embedding(node_model_path, nf.node_graph(apa_coor, apa_geo))
    edge = _graph_embedding(edge_model_path, ef.edge_graph(apa_line, apa_wall))

    return {"edge": edge, "shape": shape, "node": node}
# Gradio UI: two text inputs ("wkt", "wall") passed to outline_embedding; output shown as JSON.
gradio_interface = gradio.Interface(
    fn=outline_embedding,
    inputs=[
        gradio.Textbox(type="text", label="wkt", placeholder="wkt"),
        gradio.Textbox(type="text", label="wall", placeholder="wall"),
    ],
    outputs="json",
    title="outline embedding",
)
# The request queue (at most 5 pending requests) is enabled by .queue(); launch() no
# longer takes the deprecated ``enable_queue`` flag, which Gradio 4 rejects.
gradio_interface.queue(max_size=5, status_update_rate="auto")
gradio_interface.launch(show_error=True)