---
license: apache-2.0
datasets:
- HuggingFaceM4/WebSight
---

This model is [CogAgent-chat-18B](https://huggingface.co/THUDM/CogAgent) fine-tuned on 160K WebSight examples, using LoRA adapters of rank 8 added to the language decoder.
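
For intuition, LoRA freezes the pretrained weights and learns a low-rank update alongside each adapted layer. Below is a minimal PyTorch sketch of a rank-8 adapter on a linear layer; the scaling factor and initialization are common conventions, not details taken from this checkpoint.

```python
import torch.nn as nn

class LoRALinear(nn.Module):
    # y = W x + (alpha / r) * B(A(x)), with the pretrained W frozen.
    def __init__(self, base: nn.Linear, r: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False          # freeze pretrained weights
        self.lora_a = nn.Linear(base.in_features, r, bias=False)   # A: d_in -> r
        self.lora_b = nn.Linear(r, base.out_features, bias=False)  # B: r -> d_out
        nn.init.zeros_(self.lora_b.weight)   # B starts at zero, so the adapter is initially a no-op
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))
```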

The checkpoint is saved in the [SAT (SwissArmyTransformer)](https://github.com/THUDM/SwissArmyTransformer/) format.

Please refer to [our paper](https://arxiv.org/abs/2403.03163) and [our codebase](https://github.com/NoviScl/Design2Code/tree/main/CogVLM) to run inference.

Use of the model must comply with [the original model license](https://github.com/THUDM/CogVLM/blob/main/MODEL_LICENSE) and the original data license (CC-BY-4.0).

# Example Usage (based on SAT)

```python
import argparse
import os
import sys

# Make the CogVLM repo importable (it provides the `utils` package used below)
sys.path.insert(1, '/path/to/CogVLM')

import torch
from tqdm import tqdm
from sat.model.mixins import CachedAutoregressiveMixin
from utils.models import FineTuneTestCogAgentModel
from utils.utils import chat, llama2_tokenizer, llama2_text_processor_inference, get_image_processor

parser = argparse.ArgumentParser()
parser.add_argument('--temperature', type=float, default=0.5)
parser.add_argument('--repetition_penalty', type=float, default=1.1)
args = parser.parse_args()
args.bf16 = True
args.stream_chat = False
args.version = "chat"

# You can download the testset from https://huggingface.co/datasets/SALT-NLP/Design2Code
test_data_dir = "/path/to/Design2Code"
predictions_dir = "/path/to/design2code_18b_v0_predictions"
os.makedirs(predictions_dir, exist_ok=True)

filename_list = [filename for filename in os.listdir(test_data_dir) if filename.endswith(".png")]
world_size = 1
model, model_args = FineTuneTestCogAgentModel.from_pretrained(
    "/path/to/design2code-18b-v0",
    args=argparse.Namespace(
        deepspeed=None,
        local_rank=0,
        rank=0,
        world_size=world_size,
        model_parallel_size=1,
        mode='inference',
        skip_init=True,
        use_gpu_initialization=True,
        device='cuda',
        bf16=True,
        fp16=None,
    ),
    overwrite_args={'model_parallel_size': world_size} if world_size != 1 else {},
)
model = model.eval()
model.add_mixin('auto-regressive', CachedAutoregressiveMixin())

language_processor_version = model_args.text_processor_version if 'text_processor_version' in model_args else args.version
print("[Language processor version]:", language_processor_version)
tokenizer = llama2_tokenizer("lmsys/vicuna-7b-v1.5", signal_type=language_processor_version)
image_processor = get_image_processor(model_args.eva_args["image_size"][0])
cross_image_processor = get_image_processor(model_args.cross_image_pix) if "cross_image_pix" in model_args else None
text_processor_infer = llama2_text_processor_inference(tokenizer, 2048, model.image_length)

def get_html(image_path):
    """Generate HTML source for a single screenshot."""
    with torch.no_grad():
        history = None
        cache_image = None
        # The query is left empty: the model generates HTML from the image alone.
        query = ''
    
        response, history, cache_image = chat(
            image_path,
            model,
            text_processor_infer,
            image_processor,
            query,
            history=history,
            cross_img_processor=cross_image_processor,
            image=cache_image,
            max_length=4096,
            top_p=1.0,
            temperature=args.temperature,
            top_k=1,
            invalid_slices=text_processor_infer.invalid_slices,
            repetition_penalty=args.repetition_penalty,
            args=args
        )
    
    return response

for filename in tqdm(filename_list):
    image_path = os.path.join(test_data_dir, filename)
    generated_text = get_html(image_path)
    with open(os.path.join(predictions_dir, filename.replace(".png", ".html")), "w", encoding='utf-8') as f:
        f.write(generated_text)
```
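
Note that with `top_k=1` the decoding above is effectively greedy: temperature scaling does not change the argmax token, although `repetition_penalty` still reshapes the logits before selection.

Since the task is screenshot-to-HTML, it can be useful to render the generated pages back to images and compare them with the input screenshots. Below is a minimal sketch using Playwright; this is an illustration only, not part of the original pipeline, and the paths are placeholders.

```python
# Assumes: pip install playwright && playwright install chromium
import os
from playwright.sync_api import sync_playwright

def render_html(html_path, out_png):
    # Render a local HTML file to a full-page screenshot for visual inspection.
    with sync_playwright() as p:
        browser = p.chromium.launch()
        page = browser.new_page()
        page.goto(f"file://{os.path.abspath(html_path)}")
        page.screenshot(path=out_png, full_page=True)
        browser.close()

render_html("/path/to/design2code_18b_v0_predictions/example.html", "example_rendered.png")
```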