# CodePDE / llm_api.py
import time

from anthropic import Anthropic
from google import genai
from google.genai import types
from openai import OpenAI


def get_client(cfg):
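    """Return the API client that matches cfg.model.family_name.

    DeepSeek and Qwen are served through OpenAI-compatible endpoints, so they
    reuse the OpenAI client pointed at cfg.model.base_url.
    """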
if 'gpt' in cfg.model.family_name or cfg.model.family_name == 'o':
client = OpenAI(api_key=cfg.model.api_key)
elif 'claude' in cfg.model.family_name:
client = Anthropic(api_key=cfg.model.api_key)
elif 'deepseek' in cfg.model.family_name:
client = OpenAI(api_key=cfg.model.api_key, base_url=cfg.model.base_url)
elif 'gemini' in cfg.model.family_name:
client = genai.Client(api_key=cfg.model.api_key)
elif cfg.model.family_name == 'qwen':
client = OpenAI(api_key=cfg.model.api_key, base_url=cfg.model.base_url)
else:
raise ValueError(f'Model {cfg.model.family_name} not recognized')
return client


def generate_response(messages, cfg):
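    """Send `messages` to the provider serving cfg.model.name and return the response.

    The return type varies by branch: an OpenAI/Anthropic/Gemini response
    object, or a plain string for streamed QwQ models.
    """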
    client = get_client(cfg)
model_name = cfg.model.name
    if 'o1' in model_name or 'o3' in model_name or 'o4' in model_name:
        # o1 models do not accept a system message, so fold the system prompt
        # into the first user message.
        if 'o1' in model_name and len(messages) > 0 and messages[0]['role'] == 'system':
            system_prompt = messages[0]['content']
            messages = messages[1:]
            messages[0]['content'] = system_prompt + '\n\n' + messages[0]['content']
        # TODO: add these to the hydra config
        num_tokens = 16384
        temperature = 1.0  # o-series reasoning models only accept the default temperature
start_time = time.time()
        response = client.chat.completions.create(
            model=model_name,
            messages=messages,
            max_completion_tokens=num_tokens,
            temperature=temperature,
        )
end_time = time.time()
        print(f'{model_name} took {end_time - start_time:.2f}s to generate the response.')
return response
    if 'claude' in model_name:
        num_tokens = 8192  # Give Claude more tokens
        temperature = 0.7
        system_prompt = ''  # avoid a NameError when no system message is supplied
        if len(messages) > 0 and messages[0]['role'] == 'system':
            system_prompt = messages[0]['content']
            messages = messages[1:]
start_time = time.time()
        if cfg.model.thinking:
            num_thinking_tokens = 12288
            response = client.messages.create(
                model=model_name,
                max_tokens=num_tokens + num_thinking_tokens,
                thinking={"type": "enabled", "budget_tokens": num_thinking_tokens},
                system=system_prompt,
                messages=messages,
                # temperature must be set to 1 when thinking is enabled
                temperature=1.0,
            )
else:
response = client.messages.create(
model=model_name,
max_tokens=num_tokens,
system=system_prompt,
messages=messages,
temperature=temperature,
)
end_time = time.time()
        print(f'{model_name} took {end_time - start_time:.2f}s to generate the response.')
return response
    if 'gemini' in model_name:
        start_time = time.time()
        if len(messages) > 0 and messages[0]['role'] == 'system':
            # Gemini chat history has no system role here, so prepend the
            # system prompt to the first user message.
            system_prompt = messages[0]['content']
            messages = messages[1:]
            messages[0]['content'] = system_prompt + '\n\n' + messages[0]['content']
        # The Gemini API names the assistant role 'model'.
        for message in messages:
            if message['role'] == 'assistant':
                message['role'] = 'model'
chat = client.chats.create(
model=model_name,
history=[
types.Content(role=message['role'], parts=[types.Part(text=message['content'])])
for message in messages[:-1]
],
)
response = chat.send_message(message=messages[-1]['content'])
end_time = time.time()
        print(f'{model_name} took {end_time - start_time:.2f}s to generate the response.')
return response
    # Default: OpenAI-compatible chat completions (GPT, DeepSeek, Qwen).
    num_tokens = 4096
    temperature = 0.7
start_time = time.time()
response = client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=num_tokens,
temperature=temperature,
stream=('qwq' in model_name),
)
    if 'qwq' in model_name:
        # QwQ responses arrive as a stream: collect the answer tokens and
        # discard the separate reasoning stream.
        answer_content = ""
        for chunk in response:
            if chunk.choices:
                delta = chunk.choices[0].delta
                if hasattr(delta, 'reasoning_content') and delta.reasoning_content is not None:
                    # We don't need to print the reasoning content
                    pass
                elif delta.content:  # content can be None on some chunks
                    answer_content += delta.content
        # Note: in this branch the returned response is a plain string.
        response = answer_content
end_time = time.time()
    print(f'{model_name} took {end_time - start_time:.2f}s to generate the response.')
return response
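

# Minimal usage sketch (not part of the original module): the cfg object below
# mimics the hydra config fields this file reads (model.family_name, model.name,
# model.api_key, model.base_url, model.thinking); the model id and env var are
# illustrative assumptions, not values shipped with CodePDE.
if __name__ == '__main__':
    import os
    from types import SimpleNamespace

    cfg = SimpleNamespace(model=SimpleNamespace(
        family_name='claude',
        name='claude-3-7-sonnet-20250219',  # assumed model id, for illustration only
        api_key=os.environ.get('ANTHROPIC_API_KEY'),
        base_url=None,
        thinking=False,
    ))
    messages = [
        {'role': 'system', 'content': 'You are a PDE solver assistant.'},
        {'role': 'user', 'content': 'Sketch a finite-difference solver for the heat equation.'},
    ]
    response = generate_response(messages, cfg)
    # For the Anthropic branch without thinking, the text lives in the first content block.
    print(response.content[0].text)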