File size: 4,076 Bytes
cee775f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import torch
import gradio as gr
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig

# Summarization models offered in the model dropdown.
# Every entry must be loadable by the "summarization" pipeline, which loads
# encoder-decoder (seq2seq) models. "togethercomputer/LLaMA-2-7B-32K" was
# removed from this list: it is a decoder-only causal LM, so selecting it
# could only ever produce a load failure.
# NOTE(review): the t5-v1_1 checkpoints are pre-trained only (not fine-tuned
# for summarization), so their summaries may be poor — confirm it belongs here.
model_names = [
    "google/bigbird-pegasus-large-arxiv",
    "facebook/bart-large-cnn",
    "google/t5-v1_1-large",
    "sshleifer/distilbart-cnn-12-6",
    "allenai/led-base-16384",
    "google/pegasus-xsum",
]

# Module-level state shared by the Gradio callbacks. All three are None until
# load_model succeeds:
#   summarizer - the active summarization pipeline
#   tokenizer  - tokenizer used to count input tokens
#   max_tokens - input-length limit (in tokens) reported for the loaded model
summarizer = None
tokenizer = None
max_tokens = None


def _max_tokens_from_config(config, tokenizer):
    """Best-effort lookup of the longest input (in tokens) a model accepts.

    Config attributes are checked from most to least specific, then the
    tokenizer's ``model_max_length`` is used as a fallback.

    Args:
        config: The model's configuration object.
        tokenizer: The model's tokenizer.

    Returns:
        A positive int token limit, or None when no usable limit is found.
    """
    # LED stores its long encoder limit separately from the (short) decoder
    # limit, so it must be checked before the generic attribute names.
    for attr in ("max_encoder_position_embeddings", "max_position_embeddings", "n_positions"):
        value = getattr(config, attr, None)
        if isinstance(value, int) and value > 0:
            return value
    # `d_model` is the hidden size, not a sequence length, so it is not a
    # valid proxy for the input limit and is deliberately not used.
    limit = getattr(tokenizer, "model_max_length", None)
    # Tokenizers without a configured limit report a very large placeholder
    # value; treat anything implausibly large as "unknown".
    if isinstance(limit, int) and 0 < limit < 1_000_000:
        return limit
    return None


# Function to load the selected model
def load_model(model_name):
    """Load ``model_name`` as the active summarization model.

    On success, updates the module-level ``summarizer``, ``tokenizer`` and
    ``max_tokens`` together. On failure, leaves them unchanged.

    Args:
        model_name: Hugging Face model id to load.

    Returns:
        A status message for the UI: a success message with the token limit,
        or the error text.
    """
    global summarizer, tokenizer, max_tokens
    try:
        # Load into locals first, so a failure part-way through cannot leave
        # the globals describing two different models.
        new_summarizer = pipeline("summarization", model=model_name, torch_dtype=torch.bfloat16)
        # Reuse the pipeline's own tokenizer and config instead of fetching
        # them a second time. This also guarantees that token counts match
        # what the pipeline itself will see.
        new_tokenizer = new_summarizer.tokenizer
        limit = _max_tokens_from_config(new_summarizer.model.config, new_tokenizer)
    except Exception as e:
        return f"Failed to load model {model_name}. Error: {str(e)}"

    summarizer = new_summarizer
    tokenizer = new_tokenizer
    # With no known limit, store infinity so `num_tokens > max_tokens`
    # comparisons simply pass instead of raising TypeError on a string.
    max_tokens = limit if limit is not None else float("inf")
    shown_limit = limit if limit is not None else "Unknown"
    return f"Model {model_name} loaded successfully! Max tokens: {shown_limit}"


# Function to summarize the input text
def summarize_text(input, min_length, max_length):
    """Summarize ``input`` with the currently loaded model.

    Args:
        input: Text to summarize. The name is kept because this is the Gradio
            callback's signature.
        min_length: Minimum summary length, as a percentage (0-100) of the
            input's token count.
        max_length: Maximum summary length, as a percentage (0-100) of the
            input's token count.

    Returns:
        The summary text, or a human-readable error message.
    """
    if summarizer is None:
        return "No model loaded!"
    if not input or not input.strip():
        return "Error: Please enter some text to summarize."

    # Count the input tokens and reject inputs longer than the model's limit.
    num_tokens = len(tokenizer.encode(input))
    # max_tokens may be a non-numeric placeholder (e.g. "Unknown") when the
    # limit could not be determined. In that case the check cannot be made.
    if isinstance(max_tokens, (int, float)) and num_tokens > max_tokens:
        # Return an error message if the input text exceeds the maximum token limit
        return f"Error: The input text has {num_tokens} tokens, which exceeds the maximum allowed {max_tokens} tokens. Please enter shorter text."

    # Convert the percentages to token counts. The sliders are independent,
    # so enforce 1 <= max and min <= max before calling generation.
    max_summary_length = max(1, int(num_tokens * (max_length / 100)))
    min_summary_length = min(int(num_tokens * (min_length / 100)), max_summary_length)

    # Summarize the input text using the loaded model with specified lengths
    output = summarizer(input, min_length=min_summary_length, max_length=max_summary_length)
    return output[0]['summary_text']


# Gradio Interface: a model picker with load status, summary-length sliders,
# and the input/output textboxes.
with gr.Blocks() as demo:
    # Model selection row: dropdown of model ids next to its load button.
    with gr.Row():
        model_dropdown = gr.Dropdown(
            label="Choose a model",
            choices=model_names,
            value="sshleifer/distilbart-cnn-12-6",
        )
        load_button = gr.Button("Load Model")

    # Read-only box showing the message returned by load_model.
    load_message = gr.Textbox(interactive=False, label="Load Status")

    # Summary length bounds, as percentages of the input's token count.
    min_length_slider = gr.Slider(
        label="Minimum Summary Length (%)",
        minimum=0,
        maximum=100,
        step=1,
        value=10,
    )
    max_length_slider = gr.Slider(
        label="Maximum Summary Length (%)",
        minimum=0,
        maximum=100,
        step=1,
        value=20,
    )

    # Text in, summary out.
    input_text = gr.Textbox(lines=6, label="Input text to summarize")
    summarize_button = gr.Button("Summarize Text")
    output_text = gr.Textbox(lines=4, label="Summarized text")

    # Wire each button to its handler.
    load_button.click(fn=load_model, inputs=model_dropdown, outputs=load_message)
    summarize_button.click(
        fn=summarize_text,
        inputs=[input_text, min_length_slider, max_length_slider],
        outputs=output_text,
    )

# Start the Gradio web app.
demo.launch()