import re
import spaces
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from unidecode import unidecode
from gradio_i18n import gettext, Translate
from datasets import load_dataset
from style import custom_css, solution_style, letter_style, definition_style
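# Prompt template in the Phi-3 chat format expected by the fine-tuned solver models.
# The Italian instruction translates to: "Solve the clues in brackets to obtain a
# first reading, and use the reading key to obtain the solution of the rebus."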
template = """<s><|user|>
Risolvi gli indizi tra parentesi per ottenere una prima lettura, e usa la chiave di lettura per ottenere la soluzione del rebus.
Rebus: {rebus}
Chiave di lettura: {key}<|end|>
<|assistant|>"""
eureka5_test_data = load_dataset(
'gsarti/eureka-rebus', 'llm_sft',
data_files=["id_test.jsonl", "ood_test.jsonl"],
    split="train",
revision="1.0"
)
OUTPUTS_BASE_URL = "https://raw.githubusercontent.com/gsarti/verbalized-rebus/main/outputs/"
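# Precomputed predictions (one CSV per model) from the verbalized-rebus repository,
# shown when the user reveals the models' solutions in the guessing game.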
model_outputs = load_dataset(
"csv",
data_files={
"gpt4": OUTPUTS_BASE_URL + "prompted_models/gpt4o_results.csv",
"claude3_5_sonnet": OUTPUTS_BASE_URL + "prompted_models/claude3_5_sonnet_results.csv",
"llama3_70b": OUTPUTS_BASE_URL + "prompted_models/llama3_70b_results.csv",
"qwen_72b": OUTPUTS_BASE_URL + "prompted_models/qwen_72b_results.csv",
"phi3_mini": OUTPUTS_BASE_URL + "phi3_mini/phi3_mini_results_step_5070.csv",
"gemma2": OUTPUTS_BASE_URL + "gemma2_2b/gemma2_2b_results_step_5070.csv",
"llama3_1_8b": OUTPUTS_BASE_URL + "llama3.1_8b/llama3.1_8b_results_step_5070.csv"
}
)
def extract(span_text: str, tag: str = "span") -> str:
    """Concatenate the contents of all `<tag>` elements; text outside the tags is discarded."""
    pattern = rf"<{tag}[^>]*>(.*?)</{tag}>"
    matches = re.findall(pattern, span_text)
    return "".join(matches) if matches else ""
def parse_rebus(ex_idx: int):
    """Parse test example `ex_idx` (1-indexed) into highlighted rebus, key,
    first-pass elements and solution, ready for display."""
    prompt = eureka5_test_data[ex_idx - 1]["conversations"][0]["value"]
    gold = eureka5_test_data[ex_idx - 1]["conversations"][1]["value"]
    rebus = prompt.split("Rebus: ")[1].split("\n")[0]
    # Swap bracketed definitions for a placeholder, then highlight the loose letters.
    rebus_letters = re.sub(r"\[.*?\]", "<<<>>>", rebus)
    rebus_letters = re.sub(r"([a-zA-Z]+)", rf"""{letter_style}\1</span>""", rebus_letters)
    fp_empty = rebus_letters.replace("<<<>>>", f"{definition_style}___</span>")
    key = prompt.split("Chiave di lettura: ")[1].split("\n")[0]
    key_split = key  # plain key, split on spaces later to get the word lengths
    key_highlighted = re.sub(r"(\d+)", rf"""{solution_style}\1</span>""", key)
    # First-pass elements are either bracketed definitions or runs of verbatim letters.
    fp_elements = re.findall(r"- (.*) = (.*)", gold)
    definitions = [x[0] for x in fp_elements if x[0].startswith("[")]
    for idx, el in enumerate(fp_elements):
        if el[0].startswith("["):
            fp_elements[idx] = (re.sub(r"\[(.*?)\]", rf"""{definition_style}[\1]</span>""", el[0]), el[1])
        else:
            fp_elements[idx] = (
                f"{letter_style}{el[0]}</span>",
                f"{letter_style}{el[1]}</span>",
            )
    fp = re.findall(r"Prima lettura: (.*)", gold)[0]
    s_elements = re.findall(r"(\d+) = (.*)", gold)
    s = re.findall(r"Soluzione: (.*)", gold)[0]
for d in definitions:
rebus_letters = rebus_letters.replace("<<<>>>", d, 1)
rebus_highlighted = re.sub(r"\[(.*?)\]", rf"""{definition_style}[\1]</span>""", rebus_letters)
return {
"rebus": rebus_highlighted,
"key": key_highlighted,
"key_split": key_split,
"fp_elements": fp_elements,
"fp": fp,
"fp_empty": fp_empty,
"s_elements": s_elements,
"s": s
}
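# Model loading is disabled in this Space: solve_verbalized_rebus below only
# echoes the formatted prompt instead of generating a solution.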
#tokenizer = AutoTokenizer.from_pretrained("gsarti/phi3-mini-rebus-solver-fp16")
#model = AutoModelForCausalLM.from_pretrained("gsarti/phi3-mini-rebus-solver-fp16")
@spaces.GPU
def solve_verbalized_rebus(example, history):  # history is required by gr.ChatInterface but unused
    # The template needs the rebus and the key separately, so parse them out of the message.
    rebus = example.split("Rebus: ")[1].split("\n")[0]
    key = example.split("Chiave di lettura: ")[1].split("\n")[0]
    prompt = template.format(rebus=rebus, key=key)
    #inputs = tokenizer(prompt, return_tensors="pt")["input_ids"]
    #outputs = model.generate(input_ids = inputs, max_new_tokens = 500, use_cache = True)
    #model_generations = tokenizer.batch_decode(outputs)
    #return model_generations[0]
    return prompt
#demo = gr.ChatInterface(fn=solve_verbalized_rebus, examples=["Rebus: [Materiale espulso dai vulcani] R O [Strumento del calzolaio] [Si trovano ai lati del bacino] C I [Si ingrassano con la polenta] E I N [Contiene scorte di cibi] B [Isola in francese]\nChiave di lettura: 1 ' 5 6 5 3 3 1 14"], title="Verbalized Rebus Solver")
#demo.launch()
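# Interactive demo: a guessing game over the test examples, with optional length
# hints and a comparison against prompted and fine-tuned model predictions.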
with gr.Blocks(css=custom_css) as demo:
lang = gr.Dropdown([("English", "en"), ("Italian", "it")], value="it", label="Select language:", interactive=True)
with Translate("translations.yaml", lang, placeholder_langs=["en", "it"]):
gr.Markdown(gettext("Title"))
gr.Markdown(gettext("Intro"))
with gr.Tab(gettext("GuessingGame")):
with gr.Row():
with gr.Column():
example_id = gr.Number(1, label=gettext("CurrentExample"), minimum=1, maximum=2000, step=1, interactive=True)
with gr.Column():
show_length_hints = gr.Checkbox(False, label=gettext("ShowLengthHints"), interactive=True)
@gr.render(inputs=[example_id, show_length_hints], triggers=[demo.load, example_id.change, show_length_hints.change, lang.change])
def show_example(example_number, show_length_hints):
parsed_rebus = parse_rebus(example_number)
gr.Markdown(gettext("Instructions"))
gr.Markdown(gettext("Rebus") + f"{parsed_rebus['rebus']}</h4>"),
gr.Markdown(gettext("Key") + f"{parsed_rebus['key']}</h4>")
gr.Markdown("<br><br>")
with gr.Row():
answers: list[gr.Textbox] = []
with gr.Column(scale=2):
gr.Markdown(gettext("ProceedToResolution"))
for el_key, el_value in parsed_rebus['fp_elements']:
with gr.Row():
                                    with gr.Column(scale=1, min_width=250):  # integer scale; width is governed by min_width
gr.Markdown(f"<p>{el_key} = </p>")
if el_key.startswith('<span class="definition"') and show_length_hints:
gr.Markdown(f"<p>({len(el_value)} lettere)</p>")
                                    with gr.Column(scale=1, min_width=150):
if el_key.startswith('<span class="definition"'):
definition_answer = gr.Textbox(show_label=False, placeholder="Guess...", interactive=True, max_lines=3)
answers.append(definition_answer)
else:
gr.Markdown(el_value)
gr.Markdown("<hr>")
with gr.Column(scale=3):
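                        # Hidden components hold the raw (un-highlighted) state read by the update callback.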
key_value = gr.Markdown(parsed_rebus['key_split'], visible=False)
fp_empty = gr.Markdown(parsed_rebus['fp_empty'], visible=False)
fp = gr.Markdown(gettext("FirstPass") + f"{parsed_rebus['fp_empty']}</h4><br>")
solution_words: list[gr.Markdown] = []
clean_solution_words: list[str] = []
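                        # Segment the first-pass letters into solution words using the lengths in the key.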
clean_fp = extract(fp.value)
curr_idx = 0
                        for n_char in parsed_rebus['key_split'].split():
                            # Defensive: copy non-numeric key tokens (e.g. apostrophes) into the solution verbatim.
                            if not n_char.isdigit():
                                clean_solution_words.append(n_char)
                                solution_words.append(gr.Markdown(gettext("SolutionWord") + f"{n_char}: {solution_style}{n_char}</span></h4>"))
                                continue
                            word = clean_fp[curr_idx:curr_idx + int(n_char)].upper()
                            clean_solution_words.append(word)
                            solution_word = gr.Markdown(gettext("SolutionWord") + f"{n_char}: {solution_style}{word}</span></h4>")
                            curr_idx += int(n_char)
                            solution_words.append(solution_word)
gr.Markdown("<br>")
solution = gr.Markdown(gettext("Solution") + f"{solution_style}{' '.join(clean_solution_words)}</span></h4>")
correct_solution = gr.Markdown(gettext("CorrectSolution") + f"{solution_style}{parsed_rebus['s'].upper()}</span></h4>", visible=False)
correct_solution_shown = gr.Checkbox(False, visible=False)
gr.Markdown("<hr>")
prompted_models = gr.Markdown(gettext("PromptedModels"), visible=False)
gpt4_solution = gr.Markdown(gettext("GPT4Solution") + f"{solution_style}{model_outputs['gpt4'][example_number - 1]['solution']}</span></h4>", visible=False)
claude_solution = gr.Markdown(gettext("ClaudeSolution") + f"{solution_style}{model_outputs['claude3_5_sonnet'][example_number - 1]['solution']}</span></h4>", visible=False)
llama3_70b_solution = gr.Markdown(gettext("LLaMA370BSolution") + f"{solution_style}{model_outputs['llama3_70b'][example_number - 1]['solution']}</span></h4>", visible=False)
qwen_72b_solution = gr.Markdown(gettext("Qwen72BSolution") + f"{solution_style}{model_outputs['qwen_72b'][example_number - 1]['solution']}</span></h4>", visible=False)
models_separator = gr.Markdown("<hr>", visible=False)
trained_models = gr.Markdown(gettext("TrainedModels"), visible=False)
llama3_1_8b_solution = gr.Markdown(gettext("LLaMA318BSolution") + f"{solution_style}{model_outputs['llama3_1_8b'][example_number - 1]['solution']}</span></h4>", visible=False)
phi3_mini_solution = gr.Markdown(gettext("Phi3MiniSolution") + f"{solution_style}{model_outputs['phi3_mini'][example_number - 1]['solution']}</span></h4>", visible=False)
gemma2_solution = gr.Markdown(gettext("Gemma22BSolution") + f"{solution_style}{model_outputs['gemma2'][example_number - 1]['solution']}</span></h4>", visible=False)
models_solutions_shown = gr.Checkbox(False, visible=False)
with gr.Row():
btn_check = gr.Button(gettext("CheckSolution"), variant="primary")
btn_show = gr.Button(gettext("ShowSolution"))
btn_show_models_solutions = gr.Button(gettext("ShowModelsSolutions"))
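                # Callbacks below receive component *values* positionally, in the order of their `inputs` lists.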
                def update_fp(fp_empty, key_value, *answers):
                    # Fill the blanks in the empty first pass with the current guesses, in order.
                    len_solutions = key_value.split()
for answer in answers:
if answer is not None and answer != "":
fp_empty = fp_empty.replace("___", answer, 1)
curr_idx = 0
new_solutions = []
new_solutions_clean = []
clean_fp_empty = extract(fp_empty)
                    for n_char in len_solutions:
                        # Defensive: copy non-numeric key tokens (e.g. apostrophes) verbatim.
                        if not n_char.isdigit():
                            new_solutions_clean.append(n_char)
                            new_solutions.append(gr.Markdown(gettext("SolutionWord") + f"{n_char}: {solution_style}{n_char}</span></h4>"))
                            continue
                        word = clean_fp_empty[curr_idx:curr_idx + int(n_char)].upper()
                        new_solutions_clean.append(word)
                        new_solutions.append(gr.Markdown(gettext("SolutionWord") + f"{n_char}: {solution_style}{word}</span></h4>"))
                        curr_idx += int(n_char)
return [
gr.Markdown(gettext("FirstPass") + f"{fp_empty}</h4><br>"),
gr.Markdown(gettext("Solution") + f"{solution_style}{' '.join(new_solutions_clean)}</span></h4>")
] + new_solutions
def check_solution(solution, correct_solution):
solution = unidecode(extract(solution))
correct_solution = unidecode(extract(correct_solution))
if solution == correct_solution:
gr.Info(gettext("CorrectSolutionMsg"))
else:
gr.Info(gettext("IncorrectSolutionMsg"))
def show_solution(correct_solution, btn_show, shown):
if shown:
return gr.Markdown(correct_solution, visible=False), gr.Button(gettext("ShowSolution")), gr.Checkbox(False, visible=False)
else:
return gr.Markdown(correct_solution, visible=True), gr.Button(gettext("HideSolution")), gr.Checkbox(True, visible=False)
                def show_models_solutions(models_solutions_shown, btn_show_models_solutions, *markdowns):
                    # Toggle all model solution rows and their section headers at once.
                    visible = not models_solutions_shown
                    label = gettext("HideModelsSolutions") if visible else gettext("ShowModelsSolutions")
                    return tuple(gr.Markdown(md, visible=visible) for md in markdowns) + (
                        gr.Button(label),
                        gr.Checkbox(visible, visible=False),
                    )
for answer in answers:
answer.change(update_fp, [fp_empty, key_value, *answers], [fp, solution, *solution_words])
btn_check.click(check_solution, [solution, correct_solution], None)
btn_show.click(show_solution, [correct_solution, btn_show, correct_solution_shown], [correct_solution, btn_show, correct_solution_shown])
btn_show_models_solutions.click(show_models_solutions, [models_solutions_shown, btn_show_models_solutions, gpt4_solution, claude_solution, llama3_70b_solution, qwen_72b_solution, llama3_1_8b_solution, phi3_mini_solution, gemma2_solution, prompted_models, trained_models, models_separator], [gpt4_solution, claude_solution, llama3_70b_solution, qwen_72b_solution, llama3_1_8b_solution, phi3_mini_solution, gemma2_solution, prompted_models, trained_models, models_separator, btn_show_models_solutions, models_solutions_shown])
with gr.Tab(gettext("ModelEvaluation")):
gr.Markdown("<i>This section is under construction! Check again later 🙏</i>")
demo.launch(show_api=False)