import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the mT5-small model and tokenizer
model_name = "google/mt5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Define the chatbot function for summarization and question answering
def chatbot(user_input):
    # Tokenize the user input
    inputs = tokenizer(user_input, return_tensors="pt", max_length=512, truncation=True)
    # Generate a response (adjust max_length and num_beams for different outputs)
    outputs = model.generate(inputs["input_ids"], max_length=150, num_beams=2, early_stopping=True)
    # Decode and return the generated text
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response

# Set up the Gradio interface
demo = gr.Interface(fn=chatbot, inputs="text", outputs="text", title="mT5-Small Chatbot")

# Launch the app
demo.launch()