hysts HF staff committed on
Commit
ac14842
1 Parent(s): 7cebd9b

Make the second stage model available on HF Space

Browse files
Files changed (2) hide show
  1. app.py +3 -6
  2. model.py +15 -1
app.py CHANGED
@@ -8,12 +8,9 @@ from model import AppModel
8
 
9
  DESCRIPTION = '''# <a href="https://github.com/THUDM/CogView2">CogView2</a> (text2image)
10
 
11
- This Spaces demo runs only one of the two stages the CogView2 codebase has, due to GPU hardware limitations, with that the outputs may not match the original codebase/paper
12
- This application accepts English or Chinese as input.
13
  In general, Chinese input produces better results than English input.
14
- If you check the "Translate to Chinese" checkbox, the app will use the English to Chinese translation results with [this Space](https://huggingface.co/spaces/chinhon/translation_eng2ch) as input.
15
- But the translation model may mistranslate and the results could be poor.
16
- So, it is also a good idea to input the translation results from other translation services.
17
  '''
18
  NOTES = '''
19
  - This app is adapted from <a href="https://github.com/hysts/CogView2_demo">https://github.com/hysts/CogView2_demo</a>. It would be recommended to use the repo if you want to run the app yourself.
@@ -29,7 +26,7 @@ def set_example_text(example: list) -> list[dict]:
29
 
30
 
31
  def main():
32
- only_first_stage = True
33
  max_inference_batch_size = 8
34
  model = AppModel(max_inference_batch_size, only_first_stage)
35
 
 
8
 
9
  DESCRIPTION = '''# <a href="https://github.com/THUDM/CogView2">CogView2</a> (text2image)
10
 
11
+ The model accepts English or Chinese as input.
 
12
  In general, Chinese input produces better results than English input.
13
+ By checking the "Translate to Chinese" checkbox, the results of English to Chinese translation with this Space will be used as input. Since the translation model may mistranslate, you may want to use the translation results from other translation services.
 
 
14
  '''
15
  NOTES = '''
16
  - This app is adapted from <a href="https://github.com/hysts/CogView2_demo">https://github.com/hysts/CogView2_demo</a>. It would be recommended to use the repo if you want to run the app yourself.
 
26
 
27
 
28
  def main():
29
+ only_first_stage = False
30
  max_inference_batch_size = 8
31
  model = AppModel(max_inference_batch_size, only_first_stage)
32
 
model.py CHANGED
@@ -54,7 +54,7 @@ if os.getenv('SYSTEM') == 'spaces':
54
  names = [
55
  'coglm.zip',
56
  'cogview2-dsr.zip',
57
- #'cogview2-itersr.zip',
58
  ]
59
  for name in names:
60
  download_and_extract_cogview2_models(name)
@@ -215,6 +215,8 @@ class Model:
215
  start = time.perf_counter()
216
 
217
  model, args = InferenceModel.from_pretrained(self.args, 'coglm')
 
 
218
 
219
  elapsed = time.perf_counter() - start
220
  logger.info(f'--- done ({elapsed=:.3f}) ---')
@@ -278,8 +280,20 @@ class Model:
278
  seq, txt_len = self.preprocess_text(text)
279
  if seq is None:
280
  return None
 
281
  self.only_first_stage = only_first_stage
 
 
 
 
 
282
  tokens = self.generate_tokens(seq, txt_len, num)
 
 
 
 
 
 
283
  res = self.generate_images(seq, txt_len, tokens)
284
 
285
  elapsed = time.perf_counter() - start
 
54
  names = [
55
  'coglm.zip',
56
  'cogview2-dsr.zip',
57
+ 'cogview2-itersr.zip',
58
  ]
59
  for name in names:
60
  download_and_extract_cogview2_models(name)
 
215
  start = time.perf_counter()
216
 
217
  model, args = InferenceModel.from_pretrained(self.args, 'coglm')
218
+ if not self.args.only_first_stage:
219
+ model.transformer.cpu()
220
 
221
  elapsed = time.perf_counter() - start
222
  logger.info(f'--- done ({elapsed=:.3f}) ---')
 
280
  seq, txt_len = self.preprocess_text(text)
281
  if seq is None:
282
  return None
283
+
284
  self.only_first_stage = only_first_stage
285
+ if not self.only_first_stage or self.srg is not None:
286
+ self.srg.dsr.model.cpu()
287
+ self.srg.itersr.model.cpu()
288
+ torch.cuda.empty_cache()
289
+ self.model.transformer.to(self.device)
290
  tokens = self.generate_tokens(seq, txt_len, num)
291
+
292
+ if not self.only_first_stage:
293
+ self.model.transformer.cpu()
294
+ torch.cuda.empty_cache()
295
+ self.srg.dsr.model.to(self.device)
296
+ self.srg.itersr.model.to(self.device)
297
  res = self.generate_images(seq, txt_len, tokens)
298
 
299
  elapsed = time.perf_counter() - start