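"""Gradio demo for the EurekaRebus verbalized rebus guessing game.

Players solve the clues of an Italian verbalized rebus step by step and can
compare their answer with the solutions produced by prompted and fine-tuned
LLM solvers from the verbalized-rebus project.
"""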
import re

import spaces
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from unidecode import unidecode
from gradio_i18n import gettext, Translate
from datasets import load_dataset

from style import custom_css, solution_style, letter_style, definition_style
# Prompt template for the fine-tuned solver. The Italian instruction reads:
# "Solve the clues in brackets to obtain a first reading, and use the reading
# key to obtain the solution of the rebus."
template = """<s><|user|>
Risolvi gli indizi tra parentesi per ottenere una prima lettura, e usa la chiave di lettura per ottenere la soluzione del rebus.
Rebus: {rebus}
Chiave di lettura: {key}<|end|>
<|assistant|>"""
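# For illustration, filling the template with the example rebus used further
# below yields a prompt like:
#
#   <s><|user|>
#   Risolvi gli indizi tra parentesi per ottenere una prima lettura, e usa la
#   chiave di lettura per ottenere la soluzione del rebus.
#   Rebus: [Materiale espulso dai vulcani] R O [Strumento del calzolaio] ...
#   Chiave di lettura: 1 ' 5 6 5 3 3 1 14<|end|>
#   <|assistant|>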
# Test splits of the EurekaRebus dataset. Each example stores a user/assistant
# conversation pair: the user turn holds the rebus and key, the assistant turn
# holds the step-by-step gold solution.
eureka5_test_data = load_dataset(
    "gsarti/eureka-rebus", "llm_sft",
    data_files=["id_test.jsonl", "ood_test.jsonl"],
    split="train",
    revision="1.0",
)
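# Pre-computed predictions from the verbalized-rebus repository. Each split is
# keyed by model name and exposes a "solution" column shown in the demo below.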
OUTPUTS_BASE_URL = "https://raw.githubusercontent.com/gsarti/verbalized-rebus/main/outputs/"

model_outputs = load_dataset(
    "csv",
    data_files={
        "gpt4": OUTPUTS_BASE_URL + "prompted_models/gpt4o_results.csv",
        "claude3_5_sonnet": OUTPUTS_BASE_URL + "prompted_models/claude3_5_sonnet_results.csv",
        "llama3_70b": OUTPUTS_BASE_URL + "prompted_models/llama3_70b_results.csv",
        "qwen_72b": OUTPUTS_BASE_URL + "prompted_models/qwen_72b_results.csv",
        "phi3_mini": OUTPUTS_BASE_URL + "phi3_mini/phi3_mini_results_step_5070.csv",
        "gemma2": OUTPUTS_BASE_URL + "gemma2_2b/gemma2_2b_results_step_5070.csv",
        "llama3_1_8b": OUTPUTS_BASE_URL + "llama3.1_8b/llama3.1_8b_results_step_5070.csv",
    },
)
def extract(span_text: str, tag: str = "span") -> str:
    """Concatenate the contents of all `tag` elements found in `span_text`."""
    pattern = rf"<{tag}[^>]*>(.*?)<\/{tag}>"
    matches = re.findall(pattern, span_text)
    return "".join(matches) if matches else ""
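# Example (illustrative): extract('<span class="letter">ab</span> <span>c</span>')
# returns "abc", the concatenated contents of all <span> tags.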
def parse_rebus(ex_idx: int):
    """Parse test example `ex_idx` (1-indexed) into its display-ready parts."""
    i = eureka5_test_data[ex_idx - 1]["conversations"][0]["value"]
    o = eureka5_test_data[ex_idx - 1]["conversations"][1]["value"]
    # Highlight the rebus: bracketed definitions and standalone letters.
    rebus = i.split("Rebus: ")[1].split("\n")[0]
    rebus_letters = re.sub(r"\[.*?\]", "<<<>>>", rebus)
    rebus_letters = re.sub(r"([a-zA-Z]+)", rf"{letter_style}\1</span>", rebus_letters)
    fp_empty = rebus_letters.replace("<<<>>>", f"{definition_style}___</span>")
    key = i.split("Chiave di lettura: ")[1].split("\n")[0]
    key_split = key
    key_highlighted = re.sub(r"(\d+)", rf"{solution_style}\1</span>", key)
    # First-pass elements: (definition, answer) or (letters, letters) pairs.
    fp_elements = re.findall(r"- (.*) = (.*)", o)
    definitions = [x[0] for x in fp_elements if x[0].startswith("[")]
    for idx, el in enumerate(fp_elements):
        if el[0].startswith("["):
            fp_elements[idx] = (re.sub(r"\[(.*?)\]", rf"{definition_style}[\1]</span>", el[0]), el[1])
        else:
            fp_elements[idx] = (
                f"{letter_style}{el[0]}</span>",
                f"{letter_style}{el[1]}</span>",
            )
    fp = re.findall(r"Prima lettura: (.*)", o)[0]
    s_elements = re.findall(r"(\d+) = (.*)", o)
    s = re.findall(r"Soluzione: (.*)", o)[0]
    for d in definitions:
        rebus_letters = rebus_letters.replace("<<<>>>", d, 1)
    rebus_highlighted = re.sub(r"\[(.*?)\]", rf"{definition_style}[\1]</span>", rebus_letters)
    return {
        "rebus": rebus_highlighted,
        "key": key_highlighted,
        "key_split": key_split,
        "fp_elements": fp_elements,
        "fp": fp,
        "fp_empty": fp_empty,
        "s_elements": s_elements,
        "s": s,
    }
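# Example (illustrative): parse_rebus(1) yields the highlighted rebus and key
# for test example 1, the (clue, answer) pairs to fill in, and the gold solution.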
# Local inference is disabled in this demo; uncomment to run the fine-tuned
# solver directly.
# tokenizer = AutoTokenizer.from_pretrained("gsarti/phi3-mini-rebus-solver-fp16")
# model = AutoModelForCausalLM.from_pretrained("gsarti/phi3-mini-rebus-solver-fp16")

def solve_verbalized_rebus(example, history):
    # Fill the prompt template with the rebus and key parsed from the message.
    rebus = example.split("Rebus: ")[1].split("\n")[0]
    key = example.split(": ")[-1].strip()
    prompt = template.format(rebus=rebus, key=key)
    # inputs = tokenizer(prompt, return_tensors="pt")["input_ids"]
    # outputs = model.generate(input_ids=inputs, max_new_tokens=500, use_cache=True)
    # model_generations = tokenizer.batch_decode(outputs)
    # return model_generations[0]
    return prompt
# demo = gr.ChatInterface(fn=solve_verbalized_rebus, examples=["Rebus: [Materiale espulso dai vulcani] R O [Strumento del calzolaio] [Si trovano ai lati del bacino] C I [Si ingrassano con la polenta] E I N [Contiene scorte di cibi] B [Isola in francese]\nChiave risolutiva: 1 ' 5 6 5 3 3 1 14"], title="Verbalized Rebus Solver")
# demo.launch()
with gr.Blocks(css=custom_css) as demo:
    lang = gr.Dropdown([("English", "en"), ("Italian", "it")], value="it", label="Select language:", interactive=True)
    with Translate("translations.yaml", lang, placeholder_langs=["en", "it"]):
        gr.Markdown(gettext("Title"))
        gr.Markdown(gettext("Intro"))
        with gr.Tab(gettext("GuessingGame")):
            with gr.Row():
                with gr.Column():
                    example_id = gr.Number(1, label=gettext("CurrentExample"), minimum=1, maximum=2000, step=1, interactive=True)
                with gr.Column():
                    show_length_hints = gr.Checkbox(False, label=gettext("ShowLengthHints"), interactive=True)
            # Re-render the example view whenever the selected example or the
            # length-hints toggle changes (assumes Gradio's dynamic render API).
            @gr.render(inputs=[example_id, show_length_hints])
            def show_example(example_number, show_length_hints):
                parsed_rebus = parse_rebus(example_number)
                gr.Markdown(gettext("Instructions"))
                gr.Markdown(gettext("Rebus") + f"{parsed_rebus['rebus']}</h4>")
                gr.Markdown(gettext("Key") + f"{parsed_rebus['key']}</h4>")
                gr.Markdown("<br><br>")
                with gr.Row():
                    answers: list[gr.Textbox] = []
                    with gr.Column(scale=2):
                        gr.Markdown(gettext("ProceedToResolution"))
                        for el_key, el_value in parsed_rebus['fp_elements']:
                            with gr.Row():
                                with gr.Column(scale=0.2, min_width=250):
                                    gr.Markdown(f"<p>{el_key} = </p>")
                                    if el_key.startswith('<span class="definition"') and show_length_hints:
                                        gr.Markdown(f"<p>({len(el_value)} lettere)</p>")
                                with gr.Column(scale=0.2, min_width=150):
                                    # Definitions get a textbox for the guess; plain letters are shown as-is.
                                    if el_key.startswith('<span class="definition"'):
                                        definition_answer = gr.Textbox(show_label=False, placeholder="Guess...", interactive=True, max_lines=3)
                                        answers.append(definition_answer)
                                    else:
                                        gr.Markdown(el_value)
                        gr.Markdown("<hr>")
                    with gr.Column(scale=3):
                        # Hidden state holding the key and the empty first pass.
                        key_value = gr.Markdown(parsed_rebus['key_split'], visible=False)
                        fp_empty = gr.Markdown(parsed_rebus['fp_empty'], visible=False)
                        fp = gr.Markdown(gettext("FirstPass") + f"{parsed_rebus['fp_empty']}</h4><br>")
                        # Split the first pass into solution words following the key
                        # (key tokens are assumed to be numeric word lengths).
                        solution_words: list[gr.Markdown] = []
                        clean_solution_words: list[str] = []
                        clean_fp = extract(fp.value)
                        curr_idx = 0
                        for n_char in parsed_rebus['key_split'].split():
                            word = clean_fp[curr_idx:curr_idx + int(n_char)].upper()
                            clean_solution_words.append(word)
                            solution_word = gr.Markdown(gettext("SolutionWord") + f"{n_char}: {solution_style}{word}</span></h4>")
                            curr_idx += int(n_char)
                            solution_words.append(solution_word)
                        gr.Markdown("<br>")
                        solution = gr.Markdown(gettext("Solution") + f"{solution_style}{' '.join(clean_solution_words)}</span></h4>")
                        correct_solution = gr.Markdown(gettext("CorrectSolution") + f"{solution_style}{parsed_rebus['s'].upper()}</span></h4>", visible=False)
                        correct_solution_shown = gr.Checkbox(False, visible=False)
                        gr.Markdown("<hr>")
                        # Model solutions, hidden until the user reveals them.
                        prompted_models = gr.Markdown(gettext("PromptedModels"), visible=False)
                        gpt4_solution = gr.Markdown(gettext("GPT4Solution") + f"{solution_style}{model_outputs['gpt4'][example_number - 1]['solution']}</span></h4>", visible=False)
                        claude_solution = gr.Markdown(gettext("ClaudeSolution") + f"{solution_style}{model_outputs['claude3_5_sonnet'][example_number - 1]['solution']}</span></h4>", visible=False)
                        llama3_70b_solution = gr.Markdown(gettext("LLaMA370BSolution") + f"{solution_style}{model_outputs['llama3_70b'][example_number - 1]['solution']}</span></h4>", visible=False)
                        qwen_72b_solution = gr.Markdown(gettext("Qwen72BSolution") + f"{solution_style}{model_outputs['qwen_72b'][example_number - 1]['solution']}</span></h4>", visible=False)
                        models_separator = gr.Markdown("<hr>", visible=False)
                        trained_models = gr.Markdown(gettext("TrainedModels"), visible=False)
                        llama3_1_8b_solution = gr.Markdown(gettext("LLaMA318BSolution") + f"{solution_style}{model_outputs['llama3_1_8b'][example_number - 1]['solution']}</span></h4>", visible=False)
                        phi3_mini_solution = gr.Markdown(gettext("Phi3MiniSolution") + f"{solution_style}{model_outputs['phi3_mini'][example_number - 1]['solution']}</span></h4>", visible=False)
                        gemma2_solution = gr.Markdown(gettext("Gemma22BSolution") + f"{solution_style}{model_outputs['gemma2'][example_number - 1]['solution']}</span></h4>", visible=False)
                        models_solutions_shown = gr.Checkbox(False, visible=False)
                with gr.Row():
                    btn_check = gr.Button(gettext("CheckSolution"), variant="primary")
                    btn_show = gr.Button(gettext("ShowSolution"))
                    btn_show_models_solutions = gr.Button(gettext("ShowModelsSolutions"))
                def update_fp(fp_empty=fp_empty, key_value=key_value, *answers):
                    # Fill the blanks of the first pass with the current guesses,
                    # then re-split it into solution words following the key lengths.
                    len_solutions = key_value.split()
                    for answer in answers:
                        if answer is not None and answer != "":
                            fp_empty = fp_empty.replace("___", answer, 1)
                    curr_idx = 0
                    new_solutions = []
                    new_solutions_clean = []
                    clean_fp_empty = extract(fp_empty)
                    for n_char in len_solutions:
                        word = clean_fp_empty[curr_idx:curr_idx + int(n_char)].upper()
                        new_solutions_clean.append(word)
                        new_solutions.append(gr.Markdown(gettext("SolutionWord") + f"{n_char}: {solution_style}{word}</span></h4>"))
                        curr_idx += int(n_char)
                    return [
                        gr.Markdown(gettext("FirstPass") + f"{fp_empty}</h4><br>"),
                        gr.Markdown(gettext("Solution") + f"{solution_style}{' '.join(new_solutions_clean)}</span></h4>")
                    ] + new_solutions
                def check_solution(solution, correct_solution):
                    # Compare accent-stripped solutions to tolerate missing diacritics.
                    solution = unidecode(extract(solution))
                    correct_solution = unidecode(extract(correct_solution))
                    if solution == correct_solution:
                        gr.Info(gettext("CorrectSolutionMsg"))
                    else:
                        gr.Info(gettext("IncorrectSolutionMsg"))

                def show_solution(correct_solution, btn_show, shown):
                    # Toggle the visibility of the gold solution and the button label.
                    if shown:
                        return gr.Markdown(correct_solution, visible=False), gr.Button(gettext("ShowSolution")), gr.Checkbox(False, visible=False)
                    else:
                        return gr.Markdown(correct_solution, visible=True), gr.Button(gettext("HideSolution")), gr.Checkbox(True, visible=False)
                def show_models_solutions(
                    models_solutions_shown,
                    btn_show_models_solutions,
                    gpt4_solution,
                    claude_solution,
                    llama3_70b_solution,
                    qwen_72b_solution,
                    llama3_1_8b_solution,
                    phi3_mini_solution,
                    gemma2_solution,
                    prompted_models,
                    trained_models,
                    models_separator,
                ):
                    # Toggle the visibility of all model solution blocks at once.
                    visible = not models_solutions_shown
                    label = gettext("HideModelsSolutions") if visible else gettext("ShowModelsSolutions")
                    return (
                        gr.Markdown(gpt4_solution, visible=visible),
                        gr.Markdown(claude_solution, visible=visible),
                        gr.Markdown(llama3_70b_solution, visible=visible),
                        gr.Markdown(qwen_72b_solution, visible=visible),
                        gr.Markdown(llama3_1_8b_solution, visible=visible),
                        gr.Markdown(phi3_mini_solution, visible=visible),
                        gr.Markdown(gemma2_solution, visible=visible),
                        gr.Markdown(prompted_models, visible=visible),
                        gr.Markdown(trained_models, visible=visible),
                        gr.Markdown(models_separator, visible=visible),
                        gr.Button(label),
                        gr.Checkbox(visible, visible=False),
                    )
                # Wire the interactions: guesses update the first pass live,
                # buttons check or reveal the gold and model solutions.
                for answer in answers:
                    answer.change(update_fp, [fp_empty, key_value, *answers], [fp, solution, *solution_words])
                btn_check.click(check_solution, [solution, correct_solution], None)
                btn_show.click(show_solution, [correct_solution, btn_show, correct_solution_shown], [correct_solution, btn_show, correct_solution_shown])
                btn_show_models_solutions.click(
                    show_models_solutions,
                    [models_solutions_shown, btn_show_models_solutions, gpt4_solution, claude_solution, llama3_70b_solution, qwen_72b_solution, llama3_1_8b_solution, phi3_mini_solution, gemma2_solution, prompted_models, trained_models, models_separator],
                    [gpt4_solution, claude_solution, llama3_70b_solution, qwen_72b_solution, llama3_1_8b_solution, phi3_mini_solution, gemma2_solution, prompted_models, trained_models, models_separator, btn_show_models_solutions, models_solutions_shown],
                )
        with gr.Tab(gettext("ModelEvaluation")):
            gr.Markdown("<i>This section is under construction! Check again later 🙏</i>")

demo.launch(show_api=False)