Make the second stage model available on HF Space
app.py
CHANGED
@@ -8,12 +8,9 @@ from model import AppModel
 
 DESCRIPTION = '''# <a href="https://github.com/THUDM/CogView2">CogView2</a> (text2image)
 
-
-This application accepts English or Chinese as input.
+The model accepts English or Chinese as input.
 In general, Chinese input produces better results than English input.
-
-But the translation model may mistranslate and the results could be poor.
-So, it is also a good idea to input the translation results from other translation services.
+By checking the "Translate to Chinese" checkbox, the results of English to Chinese translation with this Space will be used as input. Since the translation model may mistranslate, you may want to use the translation results from other translation services.
 '''
 NOTES = '''
 - This app is adapted from <a href="https://github.com/hysts/CogView2_demo">https://github.com/hysts/CogView2_demo</a>. It would be recommended to use the repo if you want to run the app yourself.
@@ -29,7 +26,7 @@ def set_example_text(example: list) -> list[dict]:
 
 
 def main():
-    only_first_stage = True
+    only_first_stage = False
     max_inference_batch_size = 8
     model = AppModel(max_inference_batch_size, only_first_stage)
 
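Note on the app.py change: `only_first_stage` is the switch this commit is about. With `True`, the Space runs only the CogLM first stage and returns low-resolution images; with `False`, the CogView2 super-resolution stage runs as well. Below is a minimal sketch of the intended control flow, with hypothetical helper names; the real logic lives in model.py.

    # Sketch only: illustrates the effect of the flag, not the repo's actual code.
    def generate(model, text: str, only_first_stage: bool):
        tokens = model.first_stage_tokens(text)   # hypothetical: CogLM text-to-token stage
        if only_first_stage:
            return model.decode_low_res(tokens)   # hypothetical: return low-resolution images
        return model.super_resolve(tokens)        # hypothetical: dsr + itersr super-resolution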
model.py
CHANGED
@@ -54,7 +54,7 @@ if os.getenv('SYSTEM') == 'spaces':
     names = [
         'coglm.zip',
         'cogview2-dsr.zip',
-
+        'cogview2-itersr.zip',
     ]
     for name in names:
         download_and_extract_cogview2_models(name)
@@ -215,6 +215,8 @@ class Model:
         start = time.perf_counter()
 
         model, args = InferenceModel.from_pretrained(self.args, 'coglm')
+        if not self.args.only_first_stage:
+            model.transformer.cpu()
 
         elapsed = time.perf_counter() - start
         logger.info(f'--- done ({elapsed=:.3f}) ---')
@@ -278,8 +280,20 @@ class Model:
         seq, txt_len = self.preprocess_text(text)
         if seq is None:
             return None
+
         self.only_first_stage = only_first_stage
+        if not self.only_first_stage or self.srg is not None:
+            self.srg.dsr.model.cpu()
+            self.srg.itersr.model.cpu()
+            torch.cuda.empty_cache()
+            self.model.transformer.to(self.device)
         tokens = self.generate_tokens(seq, txt_len, num)
+
+        if not self.only_first_stage:
+            self.model.transformer.cpu()
+            torch.cuda.empty_cache()
+            self.srg.dsr.model.to(self.device)
+            self.srg.itersr.model.to(self.device)
         res = self.generate_images(seq, txt_len, tokens)
 
         elapsed = time.perf_counter() - start
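Note on the model.py change: the commit swaps the two stages between CPU and GPU around each generation step. The CogLM transformer sits on the GPU while tokens are generated, then it is moved to the CPU and the super-resolution modules (dsr and itersr) are moved in before the images are produced. The likely motivation, an assumption based on the code rather than anything stated in the commit, is that both stages do not fit in GPU memory at once on the Space hardware. A minimal, generic sketch of that pattern with plain torch modules (not the repo's classes):

    # Generic device-swapping sketch, assuming two stages that each fit on the GPU
    # only when the other one is parked on the CPU.
    import torch

    def swap_stages(to_cpu: torch.nn.Module, to_gpu: torch.nn.Module,
                    device: torch.device) -> None:
        to_cpu.cpu()                # evict the stage we are done with
        torch.cuda.empty_cache()    # release cached blocks so the next stage fits
        to_gpu.to(device)           # bring in the stage we need next

    # Usage sketch (names are illustrative):
    #   swap_stages(second_stage, first_stage, device)   # before generate_tokens
    #   tokens = first_stage(...)
    #   swap_stages(first_stage, second_stage, device)   # before generate_images
    #   images = second_stage(tokens, ...)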