import gradio as gr
from transformers import pipeline, AutoTokenizer
# Load the Hugging Face seq2seq model and its matching tokenizer.
# The pipeline does text2text generation (Czech anaphora resolution); the
# tokenizer is loaded separately so predict() can measure the input length
# when choosing generation limits.
model_path = "patrixtano/mt5-base-anaphora_czech_6e"
model_pipeline = pipeline("text2text-generation", model=model_path)
# Reuse model_path (was a duplicated literal) so the tokenizer always
# matches the pipeline's model if the id ever changes.
tokenizer = AutoTokenizer.from_pretrained(model_path)
def predict(text_input):
    """Run the anaphora-resolution model on *text_input* and return its output.

    The generation length window is derived from the tokenized input, since
    the task is to emit a near-copy of the text with antecedent tags added.
    Any failure is surfaced to the UI as an "Error: ..." string rather than
    raised, so the demo never crashes on bad input.
    """
    num_input_tokens = len(tokenizer(text_input)["input_ids"])
    # Keep the output length close to the input length (copy-with-tags task).
    gen_kwargs = dict(
        min_length=num_input_tokens + 5,
        max_length=num_input_tokens + 10,
    )
    try:
        outputs = model_pipeline(text_input, **gen_kwargs)
        return outputs[0]["generated_text"]
    except Exception as e:  # UI boundary: report the failure as text
        return f"Error: {str(e)}"
# Czech example sentences offered as clickable presets in the Gradio UI.
# Each contains a pronoun or relative clause whose antecedent the model is
# expected to mark. NOTE(review): these are verbatim demo inputs fed to the
# model as-is (including any apparent typos), so they must not be "fixed".
examples = ["""Miluji ženu s vařečkou, která umí vařit.""",
"""Zřejmě to musel fotit nějaký chatař odsousedství, nebo by to mohl taky fotit můj manžel, ale
on se obyčejně k aparátu moc neměl.""",
"""Tomáš se domluvil s Jardou, že ho odveze na nádraží."""]
# Assemble the Gradio UI: one multi-line text box in, one text box out,
# wired to predict() with the Czech example sentences as presets.
input_box = gr.Textbox(lines=5, label="Input Text")
output_box = gr.Textbox(label="Model Output")
interface = gr.Interface(
    fn=predict,
    inputs=input_box,
    outputs=output_box,
    title="Anaphora resolution demo",
    description="""Enter text into the \"Input Text\" box, include tags around the anaphora
which is to be resolved. The model generates a copy of the text with tags marking the
predicted antecedent. This demo uses the model based on the mT5 base size model.""",
    theme="default",
    examples=examples,
)
# Launch the Gradio app only when run as a script (not on import).
if __name__ == "__main__":
    # share=True additionally requests a temporary public gradio.live URL
    # on top of the local server.
    interface.launch(share=True)