File size: 2,253 Bytes
181711a
 
 
 
 
 
 
bf36e69
181711a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from transformers import AutoModelForCausalLM, AutoTokenizer
import os


# Path of the local checkpoint directory: a folder named "Qwen25llm"
# located in the same directory as this source file.
model_name = os.path.join(
    os.path.dirname(__file__),
    "Qwen25llm",
)

# Announce the load before it starts.
print("Loading Qwen2.5 files...")

# Causal-LM weights; dtype and device placement are both passed as "auto".
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")

# The tokenizer is read from the same directory as the weights.
tokenizer = AutoTokenizer.from_pretrained(model_name)


def describe(prompt, system_prompt):
    """Run one system+user chat turn through the module-level model.

    Args:
        prompt: Text placed in the "user" message.
        system_prompt: Text placed in the "system" message.

    Returns:
        The decoded reply for the single submitted conversation, with
        special tokens skipped. At most 512 new tokens are generated.
    """
    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]

    # Render the conversation to a plain string (not token ids) that ends
    # with the generation prompt.
    rendered = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

    # Batch of one, moved to whichever device the model reports.
    encoded = tokenizer([rendered], return_tensors="pt").to(model.device)
    sequences = model.generate(**encoded, max_new_tokens=512)

    # Drop the first len(prompt_ids) tokens of every output row so that only
    # the part after the input is decoded.
    completions = []
    for prompt_ids, sequence in zip(encoded.input_ids, sequences):
        completions.append(sequence[len(prompt_ids):])

    decoded = tokenizer.batch_decode(completions, skip_special_tokens=True)
    return decoded[0]


def _is_instruction_word(word):
    """Return True if *word* is an instruction word that precedes the description."""
    lowered = word.lower()
    # "select", "classif" and "find" are stems, so they also cover
    # "selects", "classify", "classification", "finding", ...
    if "select" in lowered or "classif" in lowered or "find" in lowered:
        return True
    # "all" has to be the whole word (attached punctuation ignored); a plain
    # substring test would also fire on "small", "usually", "called", ...
    return "".join(ch for ch in lowered if ch.isalpha()) == "all"


def _extract_description(prompt):
    """Return the part of *prompt* that follows its last instruction word.

    The whole prompt is returned unchanged when it contains no instruction
    word, or when nothing but whitespace follows the last one.
    """
    words = prompt.split(" ")
    last_hit = -1
    for position, word in enumerate(words):
        if _is_instruction_word(word):
            last_hit = position
    if last_hit < 0:
        return prompt
    # Cut by word position, not by searching for the keyword text again:
    # a text search can land inside a later, unrelated word.
    description = " ".join(words[last_hit + 1:]).strip()
    return description or prompt


def _parse_verdict(answer):
    """Map the model's free-text *answer* to a boolean (YES -> True, NO -> False)."""
    # Keep letters only so that "No.", "NO!" and "yes," become clean tokens.
    letters_only = "".join(ch if ch.isalpha() else " " for ch in answer)
    for token in letters_only.lower().split():
        # The first standalone yes/no word is taken as the verdict.
        if token == "no":
            return False
        if token == "yes":
            return True
    # Neither word appears on its own: fall back to the looser substring
    # check so that phrasings such as "does not belong" still count as
    # negative. Anything else is treated as a YES.
    return not ("NO" in answer or "no" in answer or "No" in answer)


def discriminate(class_name, prompt):
    """Ask the language model whether *class_name* fits the description in *prompt*.

    Instruction words at the front of the prompt (words containing "select",
    "classif" or "find", and the standalone word "all", matched
    case-insensitively) are removed; the text after the last of them is used
    as the description. The model is then asked a YES/NO question through
    ``describe``.

    Args:
        class_name: Name of the class to test.
        prompt: User instruction that contains the description.

    Returns:
        False if the model's answer is negative, True otherwise (including
        when the answer contains no recognisable verdict).
    """
    system_prompt = "You are an accurate discriminator. " \
                    "You need to determines if the class name matches the description. " \
                    "Answer with YES or NO."
    description = _extract_description(prompt)
    question = f"Does the {class_name} belong to \"{description}\"? \n\nAnswer me with YES or NO."
    result = describe(question, system_prompt)
    return _parse_verdict(result)


def get_embedding(prompt):
    """Evaluate *prompt* against each of ten fixed class names.

    Args:
        prompt: Text handed to ``discriminate`` together with each class name.

    Returns:
        A list holding one ``discriminate`` result per class name, in the
        order the names are listed below.
    """
    labels = (
        "airplane", "automobile", "bird", "cat", "deer",
        "dog", "frog", "horse", "ship", "truck",
    )
    return [discriminate(label, prompt) for label in labels]