from grammarllm.scripts.grammar_generation import generate_non_terminals, generate_grammar from grammarllm.scripts.map_terminal_tokens import generate_token_maps from grammarllm.scripts.table_parsing import parsing_table from grammarllm.modules.BaseStreamer import BaseStreamer from grammarllm.modules.PushdownAutomaton import PushdownAutomaton from grammarllm.modules.SimpleLogitProcessor import MaskLogitsProcessor import logging import re import os from collections import defaultdict from tqdm import tqdm from grammarllm.utils.common_regex import regex_dict from grammarllm.utils.examples import * from transformers import AutoTokenizer, AutoModelForCausalLM import gradio as gr import json import zipfile import spaces import torch from huggingface_hub import login login(token = os.getenv("llama_acces_token")) def pipeline(words, tokenizer, lhs, count=0, non_terminals=None, FINAL_RULES=None): #questa è + un preprocessing di ogni produzione nella rules """ Process input words to generate context-free grammar rules. This function implements a pipeline for creating grammar rules from a set of words or phrases. It processes the input through several stages: tokenization, state transition building, prefix grouping, non-terminal generation, and grammar rule creation. The generated rules are added to a master set of rules. Args: words (list): Collection of words or phrases to process. tokenizer: Tokenizer object used to convert words into tokens. lhs (str): Left-hand side symbol for grammar rules. count (int, optional): Counter for unique non-terminal generation. Defaults to 0, used to handle apices in NT rules. non_terminals (list, optional): Predefined non-terminals to use. FINAL_RULES (dict, optional): Existing grammar rules to extend. Returns: tuple: A tuple containing: - FINAL_RULES (dict): Updated dictionary of grammar rules. - count (int): Updated counter value for non-terminal generation. Dependencies: - build_SState: Creates state transitions from input words - group_by_prefix: Groups transitions by their prefixes - generate_non_terminals: Creates non-terminal symbols - generate_grammar: Generates grammar rules """ def build_SState(classes, tokenizer): SState = [] tokenized_classes = [tokenizer.tokenize(c) for c in classes] glob_count = 1 pbar = tqdm(total=len(classes), desc="Build state") for tok_class in tokenized_classes: state = 0 for token in tok_class: if token not in SState: #provare a togliere questo if se non necessario! SState.append((state,token,glob_count)) glob_count += 1 state += 1 pbar.update(1) pbar.close() logging.info(SState) #print(list(SState)) return SState def group_by_prefix(transitions): """Group transitions by their state and prefix""" grammar = defaultdict(list) # Build transition map for state, symbol, end in transitions: grammar[state].append((symbol, end)) # Group by state and prefix grouped = defaultdict(lambda: defaultdict(list)) for state, transitions_list in grammar.items(): for symbol, end in transitions_list: grouped[state][symbol].append((symbol, end)) return grouped tansitions = build_SState(words, tokenizer) grouped_data = group_by_prefix(tansitions) #Generate non-terminals G,S = generate_non_terminals(grouped_data,count=count) count+=1 #aggiunto x la question degli apici #tokenizer.eos_token grammar_rules = generate_grammar(G, S, NT=lhs, eos_symbol='|eot|', non_terminals_list=non_terminals) for key, values in grammar_rules.items(): if key in FINAL_RULES: FINAL_RULES[key].extend(values) else: FINAL_RULES[key] = values logging.info("\nGrouped Data:") for state, prefixes in grouped_data.items(): logging.info(f"State {state}:") for prefix, class_labels_list in prefixes.items(): logging.info(f" {prefix} -> {class_labels_list}") logging.info("\n Generated Non-Terminals:\n") for nt, prefix in G.items(): logging.info(f"{nt} -> {prefix}") logging.info("\n Ends Non-Terminals:\n") for nt, prefix in S.items(): logging.info(f"{nt} -> {prefix}") logging.info("\nGrammar Rules:\n") for nt, rules in grammar_rules.items(): for rule in rules: logging.info(f"{rule}") return FINAL_RULES,count def process_grammar_rules(productions, tokenizer):# forse è + una pipeline che poi porta alla final_rueles, infatti chiama la pipeline_for_general """ Process grammar production rules based on the specified task. This function iterates through production rules and handles them differently based on whether the task is 'Classification'/'VR' or 'General'. For general tasks, it separates rules with None tags for direct assignment and processes the rest. Args: productions (dict): Dictionary of grammar production rules tokenizer: Tokenizer to use for processing Returns: dict: Final grammar rules """ def extract_tags_and_others(rhs_list): tags_list = [] others_list = [] tag_pattern = re.compile(r'<<(.+?)>>') def smart_split(item): # Trova tutti i tag <<...>> e separa il resto del testo matches = list(tag_pattern.finditer(item)) parts = [] last_index = 0 for match in matches: # Aggiungi il testo prima del tag, splittato pre_text = item[last_index:match.start()] parts.extend(pre_text.strip().split()) # Aggiungi il tag intero come una sola unità parts.append(match.group(0)) last_index = match.end() # Aggiungi eventuale testo dopo l'ultimo tag post_text = item[last_index:] parts.extend(post_text.strip().split()) return parts for item in rhs_list: tags = [] others = [] if re.search(tag_pattern, item): words = smart_split(item) current_chunk = [] for word in words: match = re.fullmatch(tag_pattern, word) if match: tags.append(match.group(1)) # salva solo il contenuto del tag else: current_chunk.append(word) if current_chunk: others.append(' '.join(current_chunk)) else: others.append(None) tags_list.append(tags) others_list.append(others) else: tags_list.append([None]) others_list.append([item]) return tags_list, others_list final_rules = {} count = 0 for lhs, rhs_list in productions.items(): tags_list, non_terminals_list = extract_tags_and_others(rhs_list) filtered_tags = [] filtered_non_terminals = [] for j in range(len(tags_list)): tag_group = tags_list[j] non_terminal_group = non_terminals_list[j] if any(tag is not None for tag in tag_group): # Filter out None tags and add them directly to final_rules i = 0 while i < len(tag_group): if tag_group[i] is None: # Add rule directly to final_rules if lhs in final_rules: final_rules[lhs].append(rhs_list[i]) else: final_rules[lhs] = [rhs_list[i]] # Remove processed tag and non-terminal tag_group.pop(i) non_terminal_group.pop(i) else: # Keep tag and non-terminal for further processing filtered_tags.append(tag_group[i]) if i < len(non_terminal_group): filtered_non_terminals.append(non_terminal_group[i]) i += 1 else: # All tags are None, add rules directly final_rules.update({lhs: rhs_list}) #print(f"Filtered tags: {filtered_tags}") #DEBUG #print(f"Filtered non-terminals: {filtered_non_terminals}")#DEBUG # Process remaining tags through the general pipeline if filtered_tags: final_rules, count = pipeline( filtered_tags, tokenizer, lhs, count=count, non_terminals=filtered_non_terminals, FINAL_RULES=final_rules ) return final_rules, count def get_parsing_table_and_map_tt(tokenizer, productions=None, regex_dict=None): def write_grammar_to_file(grammar_rules): output_file = os.path.join('temp','grammar_rules.txt') os.makedirs(os.path.dirname(output_file), exist_ok=True) """Write grammar rules to a file""" with open(output_file, 'w') as f: for non_terminal, rules in grammar_rules.items(): for rule in rules: f.write(f"{non_terminal} -> {rule}\n") f.write("\n") logging.info(f"\nGrammar Rules to {output_file}") # Get final grammar rules final_rules, _ = process_grammar_rules(productions, tokenizer) #print(final_rules) #DEBUG write_grammar_to_file(final_rules) logging.info(final_rules) # Generate parsing table pars_tab = parsing_table(final_rules) # Generate token maps if regex_dict: map_terminal_tokens = generate_token_maps(tokenizer, pars_tab, regex_dict) else: map_terminal_tokens = generate_token_maps(tokenizer, pars_tab) logging.info("\nMap Terminal Tokens:\n") for key, values in map_terminal_tokens.items(): logging.info(f"{key} -> {values}") return pars_tab, map_terminal_tokens def generate_grammar_parameters(tokenizer, pars_tab, map_terminal_tokens): # Create Pushdown Automaton and initialize processors and streamer pda = PushdownAutomaton(grammar=pars_tab, startSymbol='S*', map=map_terminal_tokens) return MaskLogitsProcessor(tokenizer, pda), BaseStreamer(tokenizer, pda) def setup_logging(): """Setup logging configuration.""" log_dir = 'temp' os.makedirs(log_dir, exist_ok=True) # Ensure the log directory exists logging.basicConfig( filename=os.path.join(log_dir, 'GRAM-GEN.log'), level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', filemode='w+' # Overwrites the file every time ) def generate_text(model, tokenizer, text, logit_processor, streamer, max_new_tokens=400, do_sample=False, temperature=None, top_p=None, **kwargs): """ Genera testo vincolato dalla grammatica, con configurazione dei parametri di generazione sicura. Args: model: Il modello pre-addestrato. tokenizer: Il tokenizer del modello. text: Input text iniziale. logit_processor: Processor dei logit basato sulla grammatica. streamer: Streamer per l'output live. max_new_tokens: Numero massimo di nuovi token da generare. do_sample: Se True, abilita la generazione stocastica. temperature: Controlla la casualità (usato solo se do_sample=True). top_p: Top-p (nucleus sampling), usato solo se do_sample=True. **kwargs: Parametri aggiuntivi opzionali per model.generate(). """ try: tokenized_input = tokenizer(text, return_tensors="pt") # Safe defaults kwargs.setdefault("num_beams", 1) # beam search disattivato kwargs.setdefault("pad_token_id", tokenizer.eos_token_id) # Sicurezza num_beams if kwargs["num_beams"] != 1: logging.warning("⚠️ num_beams > 1 non è compatibile con la generazione vincolata da grammatica. Impostato automaticamente a num_beams=1.") kwargs["num_beams"] = 1 # Sampling parameters if do_sample: if temperature is not None: kwargs["temperature"] = temperature if top_p is not None: kwargs["top_p"] = top_p else: # Rimuovi parametri di sampling se presenti kwargs.pop("temperature", None) kwargs.pop("top_p", None) # Device compatibility device = model.device input_ids = tokenized_input["input_ids"].to(device) if input_ids.device != model.device: logging.warning("Errore: gli 'input_ids' sono sulla device {input_ids.device}, mentre il modello è sulla device {model.device}. Spostando 'input_ids' sulla stessa device del modello.") attention_mask = tokenized_input["attention_mask"].to(device) if attention_mask.device != model.device: logging.warning(f"Errore: l'attention_mask è sulla device {attention_mask.device}, mentre il modello è sulla device {model.device}. Spostando 'attention_mask' sulla stessa device del modello.") start = input_ids.shape[1] output = model.generate( input_ids=input_ids, attention_mask=attention_mask, do_sample=do_sample, max_new_tokens=max_new_tokens, streamer=streamer, logits_processor=[logit_processor], **kwargs ) answer = tokenizer.decode(output[0][start:], skip_special_tokens=True) return answer except Exception as e: raise RuntimeError(f"Errore nella generazione del testo: {e}") @spaces.GPU def run_grammarllm(prompt, productions_json, model_choice,regex_json): setup_logging() # Parsing productions try: productions = json.loads(productions_json) except json.JSONDecodeError: return "Errore: JSON productions non valido.", None # Regex fissa, non caricata dall'utente regex_raw = { "regex_alfanum": "[a-zA-Z0-9]+", "regex_letters": "[a-zA-Z]+", "regex_number": "\\d+", "regex_decimal": "\\d+([.,]\\d+)?", "regex_var": "[a-zA-Z_][a-zA-Z0-9_]*", "regex_)": "\\)", "regex_(": "\\(" } try: regex_raw = json.loads(regex_json) regex_dict = {key: re.compile(pattern) for key, pattern in regex_raw.items()} except (json.JSONDecodeError, re.error) as e: return f"Errore nelle regex personalizzate: {str(e)}", None try: # Selezione del modello basata sulla scelta dell'utente if model_choice == "GPT-2": model_name = "gpt2" elif model_choice == "Llama 3.2 1B": model_name = "meta-llama/Llama-3.2-1B-Instruct" #elif model_choice == "Llama 3.1 8B": # model_name = "meta-llama/Llama-3.1-8B-Instruct" else: return f"Modello non supportato: {model_choice}", None # Caricamento del tokenizer e del modello print(f"Caricamento del modello: {model_name}") tokenizer = AutoTokenizer.from_pretrained(model_name) # Configurazione del device e dtype per ottimizzare le prestazioni device = "cuda" if torch.cuda.is_available() else "cpu" if model_choice.startswith("Llama"): # Per i modelli Llama, usa torch_dtype=torch.float16 per risparmiare memoria model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True ) else: # Per GPT-2 model = AutoModelForCausalLM.from_pretrained(model_name) model = model.to(device) # Aggiungi pad_token se non esiste if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token pars_table, map_terminal_tokens = get_parsing_table_and_map_tt( tokenizer, productions=productions, regex_dict=regex_dict, ) LogitProcessor, Streamer = generate_grammar_parameters(tokenizer, pars_table, map_terminal_tokens) output = generate_text(model, tokenizer, prompt, LogitProcessor, Streamer) # Creazione del file ZIP temp_dir = "./temp" zip_path = temp_dir + ".zip" # Assicurati che temp_dir esista if os.path.exists(temp_dir): with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf: for root, dirs, files in os.walk(temp_dir): for file in files: file_path = os.path.join(root, file) arcname = os.path.relpath(file_path, temp_dir) zipf.write(file_path, arcname) else: zip_path = None # Libera la memoria del modello del model torch.cuda.empty_cache() if torch.cuda.is_available() else None return output, zip_path except Exception as e: return f"Errore durante l'inferenza: {str(e)}", None default_grammars = { "HC Grammar": json.dumps({ "S*": ["<> A", "<> B", "<> C"], "A": ["<> D", "<> E", "<> F"], "B": ["<>", "<>", "<>"], "C": ["<>", "<>", "<>"], "D": ["<>"], "E": ["<>"], "F": ["<>"] }, indent=4), "VR Grammar": json.dumps({ "S*": ["<> S*", "<> S*", "<> S*"], }, indent=4), "General Grammar": json.dumps({ 'S*': ["( LETTERS )"], 'LETTERS': ['letters number LETTERS',"ε"] }, indent=4), } def update_productions(grammar_choice): # Aggiorna textbox productions al cambio preset return default_grammars[grammar_choice] def load_file(file_obj): if file_obj is None: return "Errore: nessun file caricato." try: # In newer Gradio versions, file_obj is a path string, not a file object if isinstance(file_obj, str): # file_obj is the file path with open(file_obj, 'r', encoding='utf-8') as f: content = f.read() else: # Fallback for older Gradio versions or different file object types if hasattr(file_obj, 'name'): # file_obj has a 'name' attribute containing the path with open(file_obj.name, 'r', encoding='utf-8') as f: content = f.read() else: # Try to read directly (old behavior) content = file_obj.read().decode("utf-8") json.loads(content) # controlla che sia JSON valido return content except Exception as e: return f"Errore nel caricamento file: {str(e)}" # Interfaccia Gradio migliorata with gr.Blocks(title="GrammarLLM - enable structured generation via formal language") as demo: gr.Markdown("# GrammarLLM - enable structured generation via LLprefix") gr.Markdown("") with gr.Row(): with gr.Column(scale=2): prompt_input = gr.Textbox( label="Insert your prompt", placeholder="Type here your prompt...", lines=3 ) with gr.Column(scale=1): model_choice = gr.Dropdown( choices=["GPT-2", "Llama 3.2 1B"],#, "Llama 3.1 8B"], label="Choose the model", value="GPT-2", interactive=True ) with gr.Row(): with gr.Column(): grammar_choice = gr.Dropdown( list(default_grammars.keys()), label="Choose Productions (JSON)", value="HC Grammar", interactive=True, elem_id="grammar_choice" ) with gr.Column(): productions_upload = gr.File( label="Upload file Productions (JSON)", file_types=['.json'] ) productions_text = gr.Textbox( label="Productions (JSON)", lines=15, value=default_grammars["HC Grammar"], info="Type your here your grammar in json fromat" ) regex_text = gr.Textbox( label="Regex to define Terminals (JSON)", lines=10, value=json.dumps({ "regex_alfanum": "[a-zA-Z0-9]+", "regex_letters": "[a-zA-Z]+", "regex_number": "\\d+", "regex_decimal": "\\d+([.,]\\d+)?", "regex_var": "[a-zA-Z_][a-zA-Z0-9_]*", "regex_)": "\\)", "regex_(": "\\(" }, indent=4), info="Modify these common regex" ) with gr.Row(): submit_btn = gr.Button("🚀 Generate Output", variant="primary", size="lg") clear_btn = gr.Button("🗑️ Clean", variant="secondary") with gr.Row(): with gr.Column(): output_text = gr.Textbox( label="Output generated", lines=10, show_copy_button=True ) with gr.Column(): zip_file = gr.File(label="📦 Download ZIP (if available)") with gr.Accordion("ℹ️ About GrammarLLM and LLprefix", open=False): gr.Markdown(""" ### 📚 What is GrammarLLM? GrammarLLM enables structured text generation constrained by a formal grammar, using LLMs (Large Language Models) such as GPT-2 or LLaMA. ### 🔍 What you can do: - **Hierarchical classification**: Define class hierarchies, as shown in the "HC Grammar" example. - **Vocabulary restriction**: Specify a limited set of valid words to be used. Including examples in the prompt is highly recommended to improve output quality. - **Constrained generation**: Use LLprefix to define any regular or context-free grammar in JSON format. 📄 For more details about LLprefix and the underlying algorithms, refer to the official paper. """) # Callback: quando cambio dropdown, aggiorno productions_text grammar_choice.change( fn=update_productions, inputs=grammar_choice, outputs=productions_text, ) # Callback: quando carico file productions, aggiorno productions_text (override dropdown) productions_upload.upload( fn=load_file, inputs=productions_upload, outputs=productions_text, ) # Al submit del form chiamo run_grammarllm submit_btn.click( fn=run_grammarllm, inputs=[prompt_input, productions_text, model_choice, regex_text], outputs=[output_text, zip_file], show_progress=True ) # Funzione per pulire i campi def clear_fields(): return "", default_grammars["HC"], "", None, json.dumps({ "regex_alfanum": "[a-zA-Z0-9]+", "regex_letters": "[a-zA-Z]+", "regex_number": "\\d+", "regex_decimal": "\\d+([.,]\\d+)?", "regex_var": "[a-zA-Z_][a-zA-Z0-9_]*", "regex_)": "\\)", "regex_(": "\\(" }, indent=4) clear_btn.click( fn=clear_fields, outputs=[prompt_input, productions_text, output_text, zip_file, regex_text] ) if __name__ == "__main__": demo.launch()