LLM-As-Chatbot / scripts /hparams_explore.py
chansung's picture
update
88f55d9
import time
import itertools
import wandb
from transformers import GenerationConfig
# Authenticate against Weights & Biases. The key is left blank in the source;
# NOTE(review): presumably a real API key has to be filled in (or supplied
# through the environment) before running -- confirm.
wandb.login(key="")

# W&B project under which the runs of this exploration are created.
PROJECT = "txt_gen_test_project"

# Search grid. The exploration loop further down takes the Cartesian product
# of these three lists, i.e. 5 * 5 * 4 = 100 combinations.
generation_configs = dict(
    temperature=[0.5, 0.7, 0.8, 0.9, 1.0],
    top_p=[0.5, 0.75, 0.85, 0.95, 1.0],
    num_beams=[1, 2, 3, 4],
)

# How many texts are generated (and timed) for each combination.
num_gens = 1

# token initialization
# model initialization
# The results table has the same layout for every combination, so build the
# column list once: one column per generated text, followed by the
# hyper-parameters and the mean generation time.
first_columns = [f"gen_txt_{num}" for num in range(num_gens)]
columns = first_columns + ["temperature", "top_p", "num_beams", "time_delta"]

# Walk the full Cartesian product of the search grid. Every combination is
# recorded as its own W&B run, named after the values it used.
for temperature, top_p, num_beams in itertools.product(
    generation_configs["temperature"],
    generation_configs["top_p"],
    generation_configs["num_beams"],
):
    # NOTE(review): do_sample is not set here. In transformers, temperature
    # and top_p are sampling-only knobs, so they may have no effect unless the
    # real generation call enables sampling -- confirm once the placeholder
    # below is replaced by an actual generate() call.
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        num_beams=num_beams,
    )

    total_time_delta = 0.0
    txt_gens = []
    for _ in range(num_gens):
        # perf_counter() is monotonic, so the measured interval cannot be
        # distorted by wall-clock adjustments the way time.time() can.
        start = time.perf_counter()
        # text generation
        text = "dummy text"
        txt_gens.append(text)
        # decode outputs
        total_time_delta += time.perf_counter() - start

    # Mean seconds per generation, rounded to 4 decimals; the guard avoids a
    # ZeroDivisionError should num_gens ever be set to 0.
    avg_time_delta = round(total_time_delta / num_gens, 4) if num_gens else 0.0

    # Using the run as a context manager guarantees it is finished even when
    # building or logging the table raises, so a failure in one combination
    # does not leave a dangling active run behind.
    with wandb.init(
        project=PROJECT,
        name=f"t@{temperature}-tp@{top_p}-nb@{num_beams}",
        config=generation_config,
    ) as run:
        text_table = wandb.Table(columns=columns)
        text_table.add_data(
            *txt_gens, temperature, top_p, num_beams, avg_time_delta
        )
        run.log({
            "avg_t_delta": avg_time_delta,
            "results": text_table,
        })