Spaces:
Runtime error
Runtime error
Update
Browse files
model.py
CHANGED
@@ -215,7 +215,7 @@ class Model:
|
|
215 |
model, args = InferenceModel.from_pretrained(self.args, 'coglm')
|
216 |
|
217 |
elapsed = time.perf_counter() - start
|
218 |
-
logger.info(f'--- done ({elapsed=:.3f} ---')
|
219 |
return model, args
|
220 |
|
221 |
def load_strategy(self) -> CoglmStrategy:
|
@@ -229,7 +229,7 @@ class Model:
|
|
229 |
top_k_cluster=self.args.temp_cluster_gen)
|
230 |
|
231 |
elapsed = time.perf_counter() - start
|
232 |
-
logger.info(f'--- done ({elapsed=:.3f} ---')
|
233 |
return strategy
|
234 |
|
235 |
def load_srg(self) -> SRGroup:
|
@@ -239,7 +239,7 @@ class Model:
|
|
239 |
srg = None if self.args.only_first_stage else SRGroup(self.args)
|
240 |
|
241 |
elapsed = time.perf_counter() - start
|
242 |
-
logger.info(f'--- done ({elapsed=:.3f} ---')
|
243 |
return srg
|
244 |
|
245 |
def update_style(self, style: str) -> None:
|
@@ -264,7 +264,7 @@ class Model:
|
|
264 |
self.srg.itersr.strategy.topk = self.args.topk_itersr
|
265 |
|
266 |
elapsed = time.perf_counter() - start
|
267 |
-
logger.info(f'--- done ({elapsed=:.3f} ---')
|
268 |
|
269 |
def run(self, text: str, style: str, seed: int, only_first_stage: bool,
|
270 |
num: int) -> list[np.ndarray] | None:
|
@@ -302,7 +302,7 @@ class Model:
|
|
302 |
seq = torch.tensor(seq + [-1] * 400, device=self.device)
|
303 |
|
304 |
elapsed = time.perf_counter() - start
|
305 |
-
logger.info(f'--- done ({elapsed=:.3f} ---')
|
306 |
return seq, txt_len
|
307 |
|
308 |
@torch.inference_mode()
|
@@ -340,7 +340,7 @@ class Model:
|
|
340 |
logger.debug(f'{output_tokens.shape=}')
|
341 |
|
342 |
elapsed = time.perf_counter() - start
|
343 |
-
logger.info(f'--- done ({elapsed=:.3f} ---')
|
344 |
return output_tokens
|
345 |
|
346 |
@staticmethod
|
@@ -374,7 +374,7 @@ class Model:
|
|
374 |
res.append(decoded_img) # only the last image (target)
|
375 |
|
376 |
elapsed = time.perf_counter() - start
|
377 |
-
logger.info(f'--- done ({elapsed=:.3f} ---')
|
378 |
return res
|
379 |
|
380 |
|
|
|
215 |
model, args = InferenceModel.from_pretrained(self.args, 'coglm')
|
216 |
|
217 |
elapsed = time.perf_counter() - start
|
218 |
+
logger.info(f'--- done ({elapsed=:.3f}) ---')
|
219 |
return model, args
|
220 |
|
221 |
def load_strategy(self) -> CoglmStrategy:
|
|
|
229 |
top_k_cluster=self.args.temp_cluster_gen)
|
230 |
|
231 |
elapsed = time.perf_counter() - start
|
232 |
+
logger.info(f'--- done ({elapsed=:.3f}) ---')
|
233 |
return strategy
|
234 |
|
235 |
def load_srg(self) -> SRGroup:
|
|
|
239 |
srg = None if self.args.only_first_stage else SRGroup(self.args)
|
240 |
|
241 |
elapsed = time.perf_counter() - start
|
242 |
+
logger.info(f'--- done ({elapsed=:.3f}) ---')
|
243 |
return srg
|
244 |
|
245 |
def update_style(self, style: str) -> None:
|
|
|
264 |
self.srg.itersr.strategy.topk = self.args.topk_itersr
|
265 |
|
266 |
elapsed = time.perf_counter() - start
|
267 |
+
logger.info(f'--- done ({elapsed=:.3f}) ---')
|
268 |
|
269 |
def run(self, text: str, style: str, seed: int, only_first_stage: bool,
|
270 |
num: int) -> list[np.ndarray] | None:
|
|
|
302 |
seq = torch.tensor(seq + [-1] * 400, device=self.device)
|
303 |
|
304 |
elapsed = time.perf_counter() - start
|
305 |
+
logger.info(f'--- done ({elapsed=:.3f}) ---')
|
306 |
return seq, txt_len
|
307 |
|
308 |
@torch.inference_mode()
|
|
|
340 |
logger.debug(f'{output_tokens.shape=}')
|
341 |
|
342 |
elapsed = time.perf_counter() - start
|
343 |
+
logger.info(f'--- done ({elapsed=:.3f}) ---')
|
344 |
return output_tokens
|
345 |
|
346 |
@staticmethod
|
|
|
374 |
res.append(decoded_img) # only the last image (target)
|
375 |
|
376 |
elapsed = time.perf_counter() - start
|
377 |
+
logger.info(f'--- done ({elapsed=:.3f}) ---')
|
378 |
return res
|
379 |
|
380 |
|