""" Module for loading a LoRA fine-tuned BART model and serving an interactive Gradio interface for text generation. """ import torch import gradio as gr from transformers import AutoTokenizer from transformers import BartForConditionalGeneration from peft import PeftModel def load_model() -> tuple[AutoTokenizer, PeftModel, torch.device]: """ Load tokenizer and LoRA-enhanced model onto available device. Returns: tokenizer (AutoTokenizer): Tokenizer for text processing. model (PeftModel): Fine-tuned LoRA BART model in eval mode. device (torch.device): Computation device (GPU if available, else CPU). """ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Load tokenizer and base model tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base") base_model = BartForConditionalGeneration.from_pretrained("facebook/bart-base") # Use efficient attention if desired base_model.config.attn_implementation = "sdpa" # Load PEFT (LoRA) model for inference model = PeftModel.from_pretrained( base_model, "outputs/bart-base-reddit-lora" ).eval() model.to(device) model.eval() return tokenizer, model, device # Load once at startup tokenizer, model, device = load_model() def predict(text: str) -> str: """ Generate a text response given an input prompt. Args: text (str): The input prompt string. Returns: str: The decoded model output. """ # Tokenize and move inputs to device inputs = tokenizer( text, return_tensors="pt", padding=True, truncation=True, ).to(device) # Generate with both beam search and sampling for diversity outputs = model.generate( **inputs, max_length=128, num_beams=10, do_sample=True, length_penalty=1.2, repetition_penalty=1.3, no_repeat_ngram_size=3, top_p=0.9, temperature=0.8, early_stopping=True, eos_token_id=tokenizer.eos_token_id, ) # Decode the first generated sequence return tokenizer.decode(outputs[0], skip_special_tokens=True) def main() -> None: """ Launch Gradio web interface for interactive model inference. """ interface = gr.Interface( fn=predict, inputs=gr.Textbox(lines=5, placeholder="Broad questions often have better results (e.g. What do you think about politics right now?).", label="Your Question"), outputs=gr.Textbox(label="Mimic Bot's Comment"), title="Reddit-User-Mimic-Bot Inference (Bart-LoRA)", description="Enter a question you would ask on reddit, and our Mimic Bot would comment back! Have fun.", allow_flagging="never", ) interface.launch() if __name__ == "__main__": main()