wcy1122 committed
Commit 0ade547 · Parent: 162e6d1

update code

Files changed (1)
  1. app.py +7 -4
app.py CHANGED

@@ -20,6 +20,7 @@ def extract_gen_content(text):
 
 def _load_model_processor():
 
+    device = "cuda" if torch.cuda.is_available() else "cpu"
     local_dir = snapshot_download(
         repo_id="xiabs/DreamOmni2",
         revision="main",
@@ -32,13 +33,14 @@ def _load_model_processor():
     pipe = DreamOmni2Pipeline.from_pretrained(
         "black-forest-labs/FLUX.1-Kontext-dev",
         torch_dtype=torch.bfloat16
-    )
+    ).to(device)
     pipe.load_lora_weights(lora_dir, adapter_name="edit")
     pipe.set_adapters(["edit"], adapter_weights=[1])
 
     vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
         vlm_dir,
-        torch_dtype="bfloat16"
+        torch_dtype="bfloat16",
+        device=device
     )
     processor = AutoProcessor.from_pretrained(vlm_dir)
     return vlm_model, processor, pipe
@@ -59,7 +61,8 @@ def _launch_demo(vlm_model, processor, pipe):
     text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     image_inputs, video_inputs = process_vision_info(messages)
     inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
-    inputs = inputs.to("cuda")
+    inputs = inputs.to(device=vlm_model.device)
+    print(vlm_model.device, '++++')
 
     generated_ids = vlm_model.generate(**inputs, do_sample=False, max_new_tokens=4096)
     generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
@@ -116,7 +119,7 @@ def _launch_demo(vlm_model, processor, pipe):
     image.save(output_path)
     print(f"Edit result saved to {output_path}")
 
-
+    @spaces.GPU()
     def process_request(image_file_1, image_file_2, instruction):
         # debugpy.listen(5678)
         # print("Waiting for debugger attach...")
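The common thread in these hunks is resolving the device once and letting the pipeline, the VLM, and the inputs follow it, instead of hard-coding "cuda". A minimal sketch of that pattern, assuming the standard torch/transformers APIs; the checkpoint name below is a placeholder, not the one this Space downloads:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Resolve the target device once; everything else follows it.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32  # bf16 only as a GPU default

checkpoint = "org/some-checkpoint"  # placeholder checkpoint for illustration
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=dtype).to(device)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Move inputs to wherever the model ended up instead of hard-coding "cuda".
inputs = tokenizer(["hello"], return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=32)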
 
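The @spaces.GPU() decorator added in the last hunk comes from the spaces package used on Hugging Face ZeroGPU Spaces: a GPU is attached only while the decorated function runs. A rough sketch of the usual wiring with Gradio, with function and interface names that are illustrative rather than taken from app.py:

import gradio as gr
import spaces  # Hugging Face ZeroGPU helper; only meaningful when running on Spaces

@spaces.GPU()  # request a GPU for the duration of each call
def run(prompt):
    # the CUDA-heavy work (model.generate, pipe(...), etc.) would live here
    return f"echo: {prompt}"

demo = gr.Interface(fn=run, inputs="text", outputs="text")

if __name__ == "__main__":
    demo.launch()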