Spaces:
Runtime error
Runtime error
Update demo_model.py
Browse files- demo_model.py +1 -1
demo_model.py
CHANGED
@@ -79,7 +79,7 @@ class LGGMText2Graph_Demo(pl.LightningModule):
|
|
79 |
with torch.no_grad():
|
80 |
prompt_emb = self.text_encoder(**encoded_input).hidden_states[-1][:, 0]
|
81 |
|
82 |
-
samples = self.sample_batch(
|
83 |
|
84 |
nx_graphs = []
|
85 |
for graph in samples:
|
|
|
79 |
with torch.no_grad():
|
80 |
prompt_emb = self.text_encoder(**encoded_input).hidden_states[-1][:, 0]
|
81 |
|
82 |
+
samples = self.sample_batch(3, cond_emb = prompt_emb.to(self.device), num_nodes = num_nodes)
|
83 |
|
84 |
nx_graphs = []
|
85 |
for graph in samples:
|