"""
Module for loading a LoRA fine-tuned BART model and serving
an interactive Gradio interface for text generation.
"""
import torch
import gradio as gr
from transformers import AutoTokenizer, BartForConditionalGeneration
from peft import PeftModel
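
# Dependency note (assumption): the attn_implementation="sdpa" argument used
# below requires a transformers version in which BART supports PyTorch's
# scaled-dot-product attention; if unavailable, drop the argument to fall
# back to the default attention implementation.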


def load_model() -> tuple[AutoTokenizer, PeftModel, torch.device]:
    """
    Load the tokenizer and LoRA-enhanced model onto the available device.

    Returns:
        tokenizer (AutoTokenizer): Tokenizer for text processing.
        model (PeftModel): Fine-tuned LoRA BART model in eval mode.
        device (torch.device): Computation device (GPU if available, else CPU).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load tokenizer and base model; request SDPA at construction time, since
    # assigning attn_implementation on the config after loading does not
    # switch the attention implementation.
    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
    base_model = BartForConditionalGeneration.from_pretrained(
        "facebook/bart-base", attn_implementation="sdpa"
    )

    # Attach the PEFT (LoRA) adapter on top of the base model for inference
    model = PeftModel.from_pretrained(
        base_model, "outputs/bart-base-reddit-lora"
    )
    model.to(device)
    model.eval()
    return tokenizer, model, device
# Load once at startup
tokenizer, model, device = load_model()
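
# Illustrative sketch (not part of this app's flow): PEFT's merge_and_unload()
# can fold the LoRA weights into the base model, removing the adapter
# indirection at inference time and returning a plain
# BartForConditionalGeneration:
#
#     merged_model = PeftModel.from_pretrained(
#         BartForConditionalGeneration.from_pretrained("facebook/bart-base"),
#         "outputs/bart-base-reddit-lora",
#     ).merge_and_unload()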


def predict(text: str) -> str:
    """
    Generate a text response for a given input prompt.

    Args:
        text (str): The input prompt string.

    Returns:
        str: The decoded model output.
    """
    # Tokenize and move inputs to the model's device
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(device)

    # Beam-search multinomial sampling: num_beams > 1 combined with
    # do_sample=True trades some determinism for more diverse outputs.
    # inference_mode() disables gradient tracking during generation.
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_length=128,
            num_beams=10,
            do_sample=True,
            length_penalty=1.2,
            repetition_penalty=1.3,
            no_repeat_ngram_size=3,
            top_p=0.9,
            temperature=0.8,
            early_stopping=True,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode the first generated sequence, dropping special tokens
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
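
# Example usage (sketch): predict() can be called directly, bypassing the UI:
#
#     reply = predict("What do you think about remote work?")
#     print(reply)
#
# Since do_sample=True, repeated calls may return different outputs; calling
# torch.manual_seed(...) beforehand makes a run reproducible.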


def main() -> None:
    """
    Launch the Gradio web interface for interactive model inference.
    """
    interface = gr.Interface(
        fn=predict,
        inputs=gr.Textbox(
            lines=5,
            placeholder=(
                "Broad questions often have better results "
                "(e.g. What do you think about politics right now?)."
            ),
            label="Your Question",
        ),
        outputs=gr.Textbox(label="Mimic Bot's Comment"),
        title="Reddit-User-Mimic-Bot Inference (BART-LoRA)",
        description=(
            "Enter a question you would ask on Reddit, and our Mimic Bot "
            "will comment back! Have fun."
        ),
        allow_flagging="never",
    )
    interface.launch()
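    # Note (sketch): to expose a temporary public URL when running locally,
    # launch() also accepts share=True, i.e. interface.launch(share=True).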


if __name__ == "__main__":
    main()