Wendy-Fly commited on
Commit
bc41164
·
verified ·
1 Parent(s): 51ae8ad

Upload generate_prompt.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. generate_prompt.py +104 -0
generate_prompt.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
5
+ from qwen_vl_utils import process_vision_info
6
+ import json
7
+ from tqdm import tqdm
8
+ import os
9
+
10
+ def read_json(file_path):
11
+ with open(file_path, 'r', encoding='utf-8') as file:
12
+ data = json.load(file)
13
+ return data
14
+
15
+ def write_json(file_path, data):
16
+ with open(file_path, 'w', encoding='utf-8') as file:
17
+ json.dump(data, file, ensure_ascii=False, indent=4)
18
+
19
+ # default: Load the model on the available device(s)
20
+ print(torch.cuda.device_count())
21
+ model_path = "/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/ckpt"
22
+ # model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
23
+ # model_path, torch_dtype="auto", device_map="auto"
24
+ # )
25
+
26
+ # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
27
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
28
+ model_path,
29
+ torch_dtype=torch.bfloat16,
30
+ attn_implementation="flash_attention_2",
31
+ device_map="auto",
32
+ )
33
+
34
+ # default processor
35
+ processor = AutoProcessor.from_pretrained(model_path)
36
+ print(model.device)
37
+
38
+
39
+
40
+
41
+ data = read_json('/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/LLaMA-Factory/data/Percption.json')
42
+ save_data = []
43
+ correct_num = 0
44
+ begin = 0
45
+ end = len(data)
46
+ batch_size = 1
47
+ for batch_idx in tqdm(range(begin, end, batch_size)):
48
+ batch = data[batch_idx:batch_idx + batch_size]
49
+
50
+ image_list = []
51
+ input_text_list = []
52
+ data_list = []
53
+ save_list = []
54
+ sd_ans = []
55
+ # while True:
56
+ for idx, i in enumerate(batch):
57
+ save_ = {
58
+ "role": "user",
59
+ "content": [
60
+ {
61
+ "type": "image",
62
+ "image": img_path,
63
+ },
64
+ {"type": "text", "text": "Please help me write a prompt for image editing on this picture. The requirements are as follows: complex editing instructions should include two to five simple editing instructions involving spatial relationships (simple editing instructions such as ADD: add an object to the left of a certain object, DELETE: delete a certain object, MODIFY: change a certain object into another object). We hope that the editing instructions can have simple reasoning and can also include some abstract concept-based editing (such as making the atmosphere more romantic, or making the diet healthier, or making the boy more handsome and the girl more beautiful, etc.). Please give me clear editing instructions and also consider whether such editing instructions are reasonable."},
65
+ ],
66
+ "result":""
67
+ }
68
+ idx_real = batch_idx * batch_size + idx
69
+ messages = data[idx_real]
70
+ save_['content'][1]['image'] = messages['content'][1]['image']
71
+ save_['content'][2]['text'] = messages['content'][2]['text']
72
+
73
+ data_list.append(messages)
74
+ save_list.append(save_)
75
+
76
+ text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
77
+ image_inputs, video_inputs = process_vision_info(messages)
78
+ inputs = processor(
79
+ text=[text],
80
+ images=image_inputs,
81
+ videos=video_inputs,
82
+ padding=True,
83
+ return_tensors="pt",
84
+ )
85
+ inputs = inputs.to(model.device)
86
+
87
+ # Inference: Generation of the output
88
+ generated_ids = model.generate(**inputs, max_new_tokens=128)
89
+ generated_ids_trimmed = [
90
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
91
+ ]
92
+ output_text = processor.batch_decode(
93
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
94
+ )
95
+
96
+ for idx,x in enumerate(output_text):
97
+ idx_real = batch_idx * batch_size + idx
98
+ save_list[idx_real]['result'] = x
99
+ save_data.append(save_list[idx_real])
100
+
101
+ json_path = "image_path = '/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/magicbrush_dataset/gen.json'"
102
+ write_json(json_path,save_data)
103
+
104
+