zyanzhe committed on
Commit
733addb
1 Parent(s): 6045edf

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +96 -0
README.md CHANGED
@@ -12,3 +12,99 @@ Please refer to [our codebase](https://github.com/NoviScl/Design2Code/tree/main/
12
 
13
  Use of the model must comply with [the original model license](https://github.com/THUDM/CogVLM/blob/main/MODEL_LICENSE) and the original data license (CC-BY-4.0).
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  Use of the model must comply with [the original model license](https://github.com/THUDM/CogVLM/blob/main/MODEL_LICENSE) and the original data license (CC-BY-4.0).
14
 
15
+ # Example Usage (based on SwissArmyTransformer, SAT)
16
+
17
+ ```python
18
+ import sys
19
+ sys.path.insert(1, '/path/to/CogVLM')
20
+ from sat.model import AutoModel
21
+ import argparse
22
+ from utils.models import CogAgentModel, CogVLMModel, FineTuneTestCogAgentModel
23
+ import torch
24
+ from sat.model.mixins import CachedAutoregressiveMixin
25
+ from sat.quantization.kernels import quantize
26
+ from sat.model import AutoModel
27
+ from utils.utils import chat, llama2_tokenizer, llama2_text_processor_inference, get_image_processor
28
+ from utils.models import CogAgentModel, CogVLMModel
29
+ from tqdm import tqdm
30
+ import os
31
+ import argparse
32
+
33
+ parser = argparse.ArgumentParser()
34
+ parser.add_argument('--temperature', type=float, default=0.5)
35
+ parser.add_argument('--repetition_penalty', type=float, default=1.1)
36
+ parser.add_argument('--split_num', type=int, default=0)
37
+ args = parser.parse_args()
38
+ args.bf16 = True
39
+ args.stream_chat = False
40
+ args.version = "chat"
41
+
42
+ # You can download the test set from https://huggingface.co/datasets/SALT-NLP/Design2Code
43
+ test_data_dir = "/path/to/Design2Code"
44
+ predictions_dir = "/path/to/design2code_18b_v0_predictions"
45
+ if not os.path.exists(predictions_dir):
46
+ try:
47
+ os.makedirs(predictions_dir)
48
+ except:
49
+ pass
50
+
51
+ filename_list = [filename for filename in os.listdir(test_data_dir) if filename.endswith(".png") and int(filename[:-4]) % 4 == args.split_num]
52
+ world_size = 1
53
+ model, model_args = FineTuneTestCogAgentModel.from_pretrained(
54
+ f"/path/to/design2code-18b-v0",
55
+ args=argparse.Namespace(
56
+ deepspeed=None,
57
+ local_rank=0,
58
+ rank=0,
59
+ world_size=1,
60
+ model_parallel_size=1,
61
+ mode='inference',
62
+ skip_init=True,
63
+ use_gpu_initialization=True,
64
+ device='cuda',
65
+ bf16=True,
66
+ fp16=None), overwrite_args={'model_parallel_size': world_size} if world_size != 1 else {})
67
+ model = model.eval()
68
+ model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
69
+
70
+ language_processor_version = model_args.text_processor_version if 'text_processor_version' in model_args else args.version
71
+ print("[Language processor version]:", language_processor_version)
72
+ tokenizer = llama2_tokenizer("lmsys/vicuna-7b-v1.5", signal_type=language_processor_version)
73
+ image_processor = get_image_processor(model_args.eva_args["image_size"][0])
74
+ cross_image_processor = get_image_processor(model_args.cross_image_pix) if "cross_image_pix" in model_args else None
75
+ text_processor_infer = llama2_text_processor_inference(tokenizer, 2048, model.image_length)
76
+
77
+ def get_html(image_path):
78
+ with torch.no_grad():
79
+ history = None
80
+ cache_image = None
81
+ # We use an empty string as the query
82
+ query = ''
83
+
84
+ response, history, cache_image = chat(
85
+ image_path,
86
+ model,
87
+ text_processor_infer,
88
+ image_processor,
89
+ query,
90
+ history=history,
91
+ cross_img_processor=cross_image_processor,
92
+ image=cache_image,
93
+ max_length=4096,
94
+ top_p=1.0,
95
+ temperature=args.temperature,
96
+ top_k=1,
97
+ invalid_slices=text_processor_infer.invalid_slices,
98
+ repetition_penalty=args.repetition_penalty,
99
+ args=args
100
+ )
101
+
102
+ return response
103
+
104
+ for filename in tqdm(filename_list):
105
+ image_path = os.path.join(test_data_dir, filename)
106
+ generated_text = get_html(image_path)
107
+ with open(os.path.join(predictions_dir, filename.replace(".png", ".html")), "w", encoding='utf-8') as f:
108
+ f.write(generated_text)
109
+ ```
110
+