Use of the model must comply with [the original model license](https://github.com/THUDM/CogVLM/blob/main/MODEL_LICENSE) and the original data license (CC-BY-4.0).

# Example Usage (based on SAT)

```python
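# End-to-end inference sketch: load the Design2Code checkpoint with SAT
# (SwissArmyTransformer) and turn every screenshot in the test set into
# an HTML file. All /path/to/... values are placeholders to fill in.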
import sys
# Make the CogVLM repo importable; the `utils` imports below come from it,
# while `sat` is the SwissArmyTransformer library.
sys.path.insert(1, '/path/to/CogVLM')

import argparse
import os

import torch
from tqdm import tqdm

from sat.model import AutoModel
from sat.model.mixins import CachedAutoregressiveMixin
from sat.quantization.kernels import quantize
from utils.models import CogAgentModel, CogVLMModel, FineTuneTestCogAgentModel
from utils.utils import chat, llama2_tokenizer, llama2_text_processor_inference, get_image_processor
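
# Decoding options, plus a few extra attributes (bf16, stream_chat, version)
# that the code below and the SAT chat() utility expect to find on `args`.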
parser = argparse.ArgumentParser()
parser.add_argument('--temperature', type=float, default=0.5)
parser.add_argument('--repetition_penalty', type=float, default=1.1)
parser.add_argument('--split_num', type=int, default=0)
args = parser.parse_args()
args.bf16 = True
args.stream_chat = False
args.version = "chat"

# You can download the testset from https://huggingface.co/datasets/SALT-NLP/Design2Code
test_data_dir = "/path/to/Design2Code"
predictions_dir = "/path/to/design2code_18b_v0_predictions"
os.makedirs(predictions_dir, exist_ok=True)
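
# Shard the test set by numeric filename (e.g. 5.png) so that four copies
# of this script, each with a different --split_num, can run in parallel;
# then load the fine-tuned CogAgent checkpoint with SAT.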
filename_list = [filename for filename in os.listdir(test_data_dir) if filename.endswith(".png") and int(filename[:-4]) % 4 == args.split_num]
world_size = 1
model, model_args = FineTuneTestCogAgentModel.from_pretrained(
    "/path/to/design2code-18b-v0",
    args=argparse.Namespace(
        deepspeed=None,
        local_rank=0,
        rank=0,
        world_size=1,
        model_parallel_size=1,
        mode='inference',
        skip_init=True,
        use_gpu_initialization=True,
        device='cuda',
        bf16=True,
        fp16=None),
    overwrite_args={'model_parallel_size': world_size} if world_size != 1 else {})
model = model.eval()
# Enable cached (KV-cache) autoregressive decoding.
model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
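
# The language side uses the Vicuna-7B-v1.5 tokenizer; the image and
# cross-image processors are sized from the checkpoint's own config.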
language_processor_version = model_args.text_processor_version if 'text_processor_version' in model_args else args.version
print("[Language processor version]:", language_processor_version)
tokenizer = llama2_tokenizer("lmsys/vicuna-7b-v1.5", signal_type=language_processor_version)
image_processor = get_image_processor(model_args.eva_args["image_size"][0])
cross_image_processor = get_image_processor(model_args.cross_image_pix) if "cross_image_pix" in model_args else None
text_processor_infer = llama2_text_processor_inference(tokenizer, 2048, model.image_length)
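
# Generate the HTML for a single screenshot. With top_k=1, decoding is
# effectively greedy; max_length=4096 bounds the generated HTML.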
def get_html(image_path):
    with torch.no_grad():
        history = None
        cache_image = None
        # We use an empty string as the query
        query = ''

        response, history, cache_image = chat(
            image_path,
            model,
            text_processor_infer,
            image_processor,
            query,
            history=history,
            cross_img_processor=cross_image_processor,
            image=cache_image,
            max_length=4096,
            top_p=1.0,
            temperature=args.temperature,
            top_k=1,
            invalid_slices=text_processor_infer.invalid_slices,
            repetition_penalty=args.repetition_penalty,
            args=args
        )

    return response

# Write one .html prediction per .png screenshot.
for filename in tqdm(filename_list):
    image_path = os.path.join(test_data_dir, filename)
    generated_text = get_html(image_path)
    with open(os.path.join(predictions_dir, filename.replace(".png", ".html")), "w", encoding='utf-8') as f:
        f.write(generated_text)
```
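
Since `--split_num` keeps only the screenshots whose numeric filename satisfies `int(name) % 4 == split_num`, the test set is split into four shards and the script can be run four times in parallel. Below is a minimal launcher sketch; the script name `generate_predictions.py` and the one-GPU-per-shard mapping are assumptions for illustration, not part of the original setup:

```python
# Hypothetical launcher for the example script above. Assumes the script
# is saved as generate_predictions.py and that each shard gets its own GPU.
import os
import subprocess

procs = []
for split in range(4):  # matches the `% 4` sharding inside the script
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(split))  # one GPU per shard
    procs.append(subprocess.Popen(
        ["python", "generate_predictions.py", "--split_num", str(split)],
        env=env,
    ))
for p in procs:
    p.wait()  # block until every shard has finished
```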