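# Thin wrapper around several LLM provider SDKs: get_client picks the right
# client for a model family, and generate_response normalizes the chat call
# across providers.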
import time

from anthropic import Anthropic, NOT_GIVEN
from google import genai
from google.genai import types
from openai import OpenAI


def get_client(messages, cfg):
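    """Return an initialized API client for the model family in ``cfg``.

    ``messages`` is unused here; it is kept so the signature mirrors
    ``generate_response``.
    """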
    if 'gpt' in cfg.model.family_name or cfg.model.family_name == 'o':
        client = OpenAI(api_key=cfg.model.api_key)
    elif 'claude' in cfg.model.family_name:
        client = Anthropic(api_key=cfg.model.api_key)
    elif 'deepseek' in cfg.model.family_name:
        # DeepSeek serves an OpenAI-compatible API, so the OpenAI client is
        # reused with a custom base_url.
        client = OpenAI(api_key=cfg.model.api_key, base_url=cfg.model.base_url)
    elif 'gemini' in cfg.model.family_name:
        client = genai.Client(api_key=cfg.model.api_key)
    elif cfg.model.family_name == 'qwen':
        # Qwen is likewise served through an OpenAI-compatible endpoint.
        client = OpenAI(api_key=cfg.model.api_key, base_url=cfg.model.base_url)
    else:
        raise ValueError(f'Model {cfg.model.family_name} not recognized')
    return client


def generate_response(messages, cfg):
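    """Send ``messages`` to the provider backing ``cfg.model.name`` and
    return the response.

    Most branches return the provider's native response object; the streaming
    QwQ branch returns the accumulated answer string instead.
    """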
    client = get_client(messages, cfg)
    model_name = cfg.model.name
    if 'o1' in model_name or 'o3' in model_name or 'o4' in model_name:
        # o1 rejects a standalone system message, so fold it into the first
        # user message before sending.
        if 'o1' in model_name and len(messages) > 0 and messages[0]['role'] == 'system':
            system_prompt = messages[0]['content']
            messages = messages[1:]
            messages[0]['content'] = system_prompt + '\n\n' + messages[0]['content']
        # TODO: add these to the hydra config
        num_tokens = 16384
        temperature = 1.0
        start_time = time.time()
        response = client.chat.completions.create(
            model=model_name,
            messages=messages,
            max_completion_tokens=num_tokens,
            temperature=temperature)
        end_time = time.time()
        print(f'{model_name} took {end_time - start_time:.2f}s to generate the response.')
        return response
    
    if 'claude' in model_name:
        num_tokens = 8192  # Give Claude more output tokens than the default branch
        temperature = 0.7

        # Anthropic takes the system prompt as a separate argument rather than
        # as the first message; default to NOT_GIVEN so the calls below do not
        # raise a NameError when no system message is provided.
        system_prompt = NOT_GIVEN
        if len(messages) > 0 and messages[0]['role'] == 'system':
            system_prompt = messages[0]['content']
            messages = messages[1:]

        start_time = time.time()
        if cfg.model.thinking:
            num_thinking_tokens = 12288
            response = client.messages.create(
                model=model_name,
                # max_tokens must cover both the thinking budget and the answer
                max_tokens=num_tokens + num_thinking_tokens,
                thinking={"type": "enabled", "budget_tokens": num_thinking_tokens},
                system=system_prompt,
                messages=messages,
                # The API requires temperature=1 when extended thinking is enabled
                temperature=1.0,
            )
        else:
            response = client.messages.create(
                model=model_name,
                max_tokens=num_tokens,
                system=system_prompt,
                messages=messages,
                temperature=temperature,
            )
        end_time = time.time()
        print(f'{model_name} took {end_time - start_time:.2f}s to generate the response.')
        return response

    if 'gemini' in model_name:
        start_time = time.time()
        if len(messages) > 0 and messages[0]['role'] == 'system':
            # No separate system slot is used here, so prepend the system
            # prompt to the first user message.
            system_prompt = messages[0]['content']
            messages = messages[1:]
            messages[0]['content'] = system_prompt + '\n\n' + messages[0]['content']

        # The google-genai SDK calls the assistant role 'model'
        for message in messages:
            if message['role'] == 'assistant':
                message['role'] = 'model'
    
        chat = client.chats.create(
            model=model_name,
            history=[
                types.Content(role=message['role'], parts=[types.Part(text=message['content'])])
                for message in messages[:-1]
            ],
        )
        response = chat.send_message(message=messages[-1]['content'])
        end_time = time.time()
        print(f'{model_name} took {end_time - start_time:.2f}s to generate the response.')
        return response

    # Default branch: OpenAI-compatible chat completions (GPT, DeepSeek, Qwen)
    num_tokens = 4096
    temperature = 0.7

    start_time = time.time()
    response = client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=num_tokens,
        temperature=temperature,
        # QwQ is consumed as a stream (collected below); other models
        # return a complete response object
        stream=('qwq' in model_name),
    )

    if 'qwq' in model_name:
        # Collect the streamed answer, discarding reasoning-only deltas
        answer_content = ""
        for chunk in response:
            if chunk.choices:
                delta = chunk.choices[0].delta
                if getattr(delta, 'reasoning_content', None) is not None:
                    # Skip the reasoning trace; it is not part of the answer
                    continue
                if delta.content:
                    answer_content += delta.content
        response = answer_content

    end_time = time.time()
    print(f'{model_name} took {end_time - start_time:.2f}s to generate the response.')
    return response
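

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module: SimpleNamespace
    # stands in for the hydra config mentioned in the TODO above, and the
    # model name / API key below are placeholders, not values from the source.
    from types import SimpleNamespace

    cfg = SimpleNamespace(model=SimpleNamespace(
        family_name='gpt',
        name='gpt-4o-mini',
        api_key='YOUR_API_KEY',
        base_url=None,
        thinking=False,
    ))
    messages = [
        {'role': 'system', 'content': 'You are a concise assistant.'},
        {'role': 'user', 'content': 'Say hello in one sentence.'},
    ]
    # For the OpenAI default branch this returns a ChatCompletion object
    response = generate_response(messages, cfg)
    print(response.choices[0].message.content)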