Spaces:
Runtime error
Runtime error
import time | |
import itertools | |
import wandb | |
from transformers import GenerationConfig | |
wandb.login(key="") | |
PROJECT="txt_gen_test_project" | |
generation_configs = { | |
"temperature": [0.5, 0.7, 0.8, 0.9, 1.0], | |
"top_p": [0.5, 0.75, 0.85, 0.95, 1.0], | |
"num_beams": [1, 2, 3, 4] | |
} | |
num_gens = 1 | |
# token initialization | |
# model initialization | |
for comb in itertools.product(generation_configs['temperature'], | |
generation_configs['top_p'], | |
generation_configs['num_beams']): | |
temperature = comb[0] | |
top_p = comb[1] | |
num_beams = comb[2] | |
generation_config = GenerationConfig( | |
temperature=temperature, | |
top_p=top_p, | |
num_beams=num_beams, | |
) | |
first_columns = [f"gen_txt_{num}" for num in range(num_gens)] | |
columns = first_columns + ["temperature", "top_p", "num_beams", "time_delta"] | |
avg_time_delta = 0 | |
txt_gens = [] | |
for i in range(num_gens): | |
start = time.time() | |
# text generation | |
text = "dummy text" | |
txt_gens.append(text) | |
# decode outputs | |
end = time.time() | |
t_delta = end - start | |
avg_time_delta = avg_time_delta + t_delta | |
avg_time_delta = round(avg_time_delta / num_gens, 4) | |
wandb.init( | |
project=PROJECT, | |
name=f"t@{temperature}-tp@{top_p}-nb@{num_beams}", | |
config=generation_config, | |
) | |
text_table = wandb.Table(columns=columns) | |
text_table.add_data(*txt_gens, temperature, top_p, num_beams, avg_time_delta) | |
wandb.log({ | |
"avg_t_delta": avg_time_delta, | |
"results": text_table | |
}) | |
wandb.finish() | |