
ryanzhangfan committed
Commit 6380db8
1 Parent(s): 116a3d4

Update app.py

Files changed (1)
  1. app.py +11 -9
app.py CHANGED
@@ -24,6 +24,8 @@ subprocess.run(
     shell=True,
 )
 
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
 # Model paths
 EMU_GEN_HUB = "BAAI/Emu3-Gen"
 EMU_CHAT_HUB = "BAAI/Emu3-Chat"
@@ -33,28 +35,28 @@ VQ_HUB = "BAAI/Emu3-VisionTokenizer"
 # Emu3-Gen model and processor
 gen_model = AutoModelForCausalLM.from_pretrained(
     EMU_GEN_HUB,
-    device_map="cuda:0",
+    device_map="cpu",
     torch_dtype=torch.bfloat16,
     attn_implementation="flash_attention_2",
     trust_remote_code=True,
-)
+).to(device)
 
 # Emu3-Chat model and processor
 chat_model = AutoModelForCausalLM.from_pretrained(
     EMU_CHAT_HUB,
-    device_map="cuda:0",
+    device_map="cpu",
     torch_dtype=torch.bfloat16,
     attn_implementation="flash_attention_2",
     trust_remote_code=True,
-)
+).to(device)
 
 tokenizer = AutoTokenizer.from_pretrained(EMU_CHAT_HUB, trust_remote_code=True)
 image_processor = AutoImageProcessor.from_pretrained(
     VQ_HUB, trust_remote_code=True
 )
 image_tokenizer = AutoModel.from_pretrained(
-    VQ_HUB, device_map="cuda:0", trust_remote_code=True
-).eval()
+    VQ_HUB, device_map="cpu", trust_remote_code=True
+).eval().to(device)
 processor = Emu3Processor(
     image_processor, image_tokenizer, tokenizer
 )
@@ -97,7 +99,7 @@ def generate_image(prompt):
         UnbatchedClassifierFreeGuidanceLogitsProcessor(
             classifier_free_guidance,
             gen_model,
-            unconditional_ids=neg_inputs.input_ids.to("cuda:0"),
+            unconditional_ids=neg_inputs.input_ids.to(device),
         ),
         PrefixConstrainedLogitsProcessor(
             constrained_fn,
@@ -108,7 +110,7 @@ def generate_image(prompt):
 
     # Generate
     outputs = gen_model.generate(
-        pos_inputs.input_ids.to("cuda:0"),
+        pos_inputs.input_ids.to(device),
         generation_config=GENERATION_CONFIG,
         logits_processor=logits_processor,
     )
@@ -139,7 +141,7 @@ def vision_language_understanding(image, text):
 
     # Generate
     outputs = chat_model.generate(
-        inputs.input_ids.to("cuda:0"),
+        inputs.input_ids.to(device),
         generation_config=GENERATION_CONFIG,
         max_new_tokens=320,
     )
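The commit replaces every hard-coded "cuda:0" with a single device variable resolved once at startup, loading weights on CPU and moving them to that device afterwards. A minimal self-contained sketch of the same device-agnostic pattern is below; the sshleifer/tiny-gpt2 checkpoint and the "hello" prompt are stand-ins for illustration only, not anything this Space uses, and passing device_map requires the accelerate package.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Resolve the target device once; everything downstream refers to it,
# so no line of the script assumes "cuda:0" exists.
device = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_HUB = "sshleifer/tiny-gpt2"  # hypothetical stand-in checkpoint

# Load on CPU, then move: this mirrors the commit's
# device_map="cpu" + .to(device) pattern.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_HUB,
    device_map="cpu",
).to(device).eval()

tokenizer = AutoTokenizer.from_pretrained(MODEL_HUB)

# Inputs must live on the same device as the model before generate().
inputs = tokenizer("hello", return_tensors="pt")
outputs = model.generate(inputs.input_ids.to(device), max_new_tokens=8)
print(tokenizer.decode(outputs[0]))

Loading on CPU and moving afterwards costs one extra host-to-device copy when a GPU is present, but from_pretrained no longer fails outright on CPU-only hosts, so the same script starts everywhere.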