# FineTuningModel / app.py
# Author: KB-Infinity-Tech
# Last commit: "Update app.py" (433126f, verified)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# -----------------------------
# LOAD MODELS
# -----------------------------
# Hub repo ids of the two checkpoints being compared.
BASE_MODEL = "google/flan-t5-small"
FINETUNED_MODEL = "KB-Infinity-Tech/t5-samsum-mini"

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


def _load_seq2seq(repo_id):
    """Load the tokenizer and seq2seq model for *repo_id*.

    The model is moved onto the module-level ``device`` before being
    returned as the second element of a ``(tokenizer, model)`` pair.
    """
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForSeq2SeqLM.from_pretrained(repo_id)
    model.to(device)
    return tokenizer, model


base_tokenizer, base_model = _load_seq2seq(BASE_MODEL)
fine_tokenizer, fine_model = _load_seq2seq(FINETUNED_MODEL)
# -----------------------------
# GENERATE FUNCTION
# -----------------------------
def _generate_summary(tokenizer, model, prompt, max_new_tokens):
    """Run one generation pass and return the decoded text.

    The prompt is tokenized with truncation enabled and moved to the
    module-level ``device``. It is then passed to ``model.generate``, and the
    first returned sequence is decoded with special tokens stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


def summarize(text, *, max_new_tokens=60):
    """Summarize *text* with both the base model and the fine-tuned model.

    Args:
        text: Dialogue to summarize. ``None`` or whitespace-only input is
            treated as empty.
        max_new_tokens: Upper bound on tokens generated by each model
            (default 60, the value previously hard-coded).

    Returns:
        A ``(base_summary, fine_summary)`` tuple of strings. Both are ``""``
        when the input is empty.
    """
    # Skip inference for blank input: there is no dialogue to summarize,
    # and prompting with only the task prefix would waste two generate()
    # calls. This also avoids a TypeError when concatenating None.
    if text is None or not text.strip():
        return "", ""

    # Both models receive the identical task-prefixed prompt so that the
    # outputs are directly comparable.
    prompt = "summarize: " + text
    base_summary = _generate_summary(base_tokenizer, base_model, prompt, max_new_tokens)
    fine_summary = _generate_summary(fine_tokenizer, fine_model, prompt, max_new_tokens)
    return base_summary, fine_summary
# -----------------------------
# UI
# -----------------------------
# Components are created in the same order as before: the dialogue input,
# then the base-model output, then the fine-tuned output.
dialogue_box = gr.Textbox(lines=8, label="Dialogue")
base_output_box = gr.Textbox(label="Base Model (FLAN-T5)")
fine_output_box = gr.Textbox(label="Fine-Tuned Model")

# summarize() returns (base_summary, fine_summary); the output list is in the
# same order so each value lands in the matching box.
iface = gr.Interface(
    fn=summarize,
    inputs=dialogue_box,
    outputs=[base_output_box, fine_output_box],
    title="🧠 T5 Summarization Compare",
    description="Compare base FLAN-T5 vs your fine-tuned SAMSum model",
)

if __name__ == "__main__":
    # Launch the app only when this file is executed directly, not on import.
    iface.launch()