Spaces:
Runtime error
Runtime error
liuhaotian
committed on
Commit
·
087de09
1
Parent(s):
01d67d8
Fix
Browse files
- app.py +2 -2
- dataset/tsv_dataset.py +1 -1
- gligen/projection_matrix.pth +3 -0
- gligen/task_grounded_generation.py +2 -9
app.py
CHANGED
@@ -27,8 +27,8 @@ def parse_option():
|
|
27 |
parser.add_argument("--guidance_scale", type=float, default=5, help="")
|
28 |
parser.add_argument("--alpha_scale", type=float, default=1, help="scale tanh(alpha). If 0, the behaviour is same as original model")
|
29 |
parser.add_argument("--load-text-box-generation", type=arg_bool, default=True, help="Load text-box generation pipeline.")
|
30 |
-
parser.add_argument("--load-text-box-inpainting", type=arg_bool, default=
|
31 |
-
parser.add_argument("--load-text-image-box-generation", type=arg_bool, default=
|
32 |
args = parser.parse_args()
|
33 |
return args
|
34 |
args = parse_option()
|
|
|
27 |
parser.add_argument("--guidance_scale", type=float, default=5, help="")
|
28 |
parser.add_argument("--alpha_scale", type=float, default=1, help="scale tanh(alpha). If 0, the behaviour is same as original model")
|
29 |
parser.add_argument("--load-text-box-generation", type=arg_bool, default=True, help="Load text-box generation pipeline.")
|
30 |
+
parser.add_argument("--load-text-box-inpainting", type=arg_bool, default=True, help="Load text-box inpainting pipeline.")
|
31 |
+
parser.add_argument("--load-text-image-box-generation", type=arg_bool, default=True, help="Load text-image-box generation pipeline.")
|
32 |
args = parser.parse_args()
|
33 |
return args
|
34 |
args = parse_option()
|
dataset/tsv_dataset.py
CHANGED
@@ -190,7 +190,7 @@ class TSVDataset(BaseDataset):
|
|
190 |
self.which_layer_image = which_layer[1]
|
191 |
|
192 |
#self.projection_matrix = torch.load(os.path.join(os.path.dirname(__file__), 'projection_matrix') )
|
193 |
-
self.projection_matrix = torch.load('projection_matrix')
|
194 |
|
195 |
# Load tsv data
|
196 |
self.tsv_file = TSVFile(self.tsv_path)
|
|
|
190 |
self.which_layer_image = which_layer[1]
|
191 |
|
192 |
#self.projection_matrix = torch.load(os.path.join(os.path.dirname(__file__), 'projection_matrix') )
|
193 |
+
self.projection_matrix = torch.load('projection_matrix.pth')
|
194 |
|
195 |
# Load tsv data
|
196 |
self.tsv_file = TSVFile(self.tsv_path)
|
gligen/projection_matrix.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:819d51fde084e16e5960323c8bafba07fa8ee727e5403e5e4bdced4333c68faa
|
3 |
+
size 2360043
|
gligen/task_grounded_generation.py
CHANGED
@@ -107,7 +107,7 @@ def get_clip_feature(model, processor, input, is_image=False):
|
|
107 |
if feature_type[1] == 'after_renorm':
|
108 |
feature = feature*28.7
|
109 |
if feature_type[1] == 'after_reproject':
|
110 |
-
feature = project( feature, torch.load('gligen/projection_matrix').cuda().T ).squeeze(0)
|
111 |
feature = ( feature / feature.norm() ) * 28.7
|
112 |
feature = feature.unsqueeze(0)
|
113 |
else:
|
@@ -249,16 +249,9 @@ def grounded_generation_box(loaded_model_list, instruction, *args, **kwargs):
|
|
249 |
|
250 |
|
251 |
# ------------- other logistics ------------- #
|
252 |
-
os.makedirs( os.path.join(save_folder, 'images'), exist_ok=True)
|
253 |
-
os.makedirs( os.path.join(save_folder, 'layout'), exist_ok=True)
|
254 |
-
os.makedirs( os.path.join(save_folder, 'overlay'), exist_ok=True)
|
255 |
-
|
256 |
-
start = len( os.listdir(os.path.join(save_folder, 'images')) )
|
257 |
-
image_ids = list(range(start,start+batch_size))
|
258 |
-
print(image_ids)
|
259 |
|
260 |
sample_list = []
|
261 |
-
for
|
262 |
sample = torch.clamp(sample, min=-1, max=1) * 0.5 + 0.5
|
263 |
sample = sample.cpu().numpy().transpose(1,2,0) * 255
|
264 |
sample = Image.fromarray(sample.astype(np.uint8))
|
|
|
107 |
if feature_type[1] == 'after_renorm':
|
108 |
feature = feature*28.7
|
109 |
if feature_type[1] == 'after_reproject':
|
110 |
+
feature = project( feature, torch.load('gligen/projection_matrix.pth').cuda().T ).squeeze(0)
|
111 |
feature = ( feature / feature.norm() ) * 28.7
|
112 |
feature = feature.unsqueeze(0)
|
113 |
else:
|
|
|
249 |
|
250 |
|
251 |
# ------------- other logistics ------------- #
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
252 |
|
253 |
sample_list = []
|
254 |
+
for sample in samples_fake:
|
255 |
sample = torch.clamp(sample, min=-1, max=1) * 0.5 + 0.5
|
256 |
sample = sample.cpu().numpy().transpose(1,2,0) * 255
|
257 |
sample = Image.fromarray(sample.astype(np.uint8))
|