# Author: chkla
# Commit d3adc06: "add upload todo for the future"
from transformers import pipeline
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
import gradio as gr
import torch
import pandas as pd
# Load the fine-tuned topic classifier and its tokenizer once, then hand the
# already-loaded objects to the pipeline. Passing the hub ids as strings
# instead would make `pipeline` fetch and instantiate the weights a second time.
MODEL_NAME = "chkla/parlbert-topic-german"
TOKENIZER_NAME = "bert-base-german-cased"

model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
# return_all_scores=False: each prediction carries only the top label/score,
# which predict_topic reads via prediction[0]["label"].
pipeline_classification_topics = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    return_all_scores=False,
)
def upload_file(files):
    """Collect the ``name`` attribute (file path) of each uploaded file.

    Args:
        files: Iterable of uploaded-file objects; each must expose ``.name``.

    Returns:
        list: The ``name`` of every file, in the order they were given.
    """
    paths = []
    for uploaded in files:
        paths.append(uploaded.name)
    return paths
def predict_topic(input_text):
    """Classify a German political text and return its predicted topic label.

    Args:
        input_text: Text entered by the user in the Gradio textbox.

    Returns:
        The label of the top-scoring topic, or an empty string when the input
        is empty or whitespace-only (there is nothing to classify).
    """
    if not input_text or not input_text.strip():
        return ""
    # truncation=True: the text-classification pipeline does not truncate by
    # default, so inputs longer than the model's maximum sequence length
    # (presumably 512 tokens for this BERT tokenizer) would fail in the
    # forward pass instead of yielding a prediction.
    prediction = pipeline_classification_topics(input_text, truncation=True)
    return prediction[0]["label"]
# Build Gradio interface
with gr.Blocks() as demo:
    # Title / instructions
    gr.Markdown('''## Topic Modelling for German Political Texts''')
    # Input: the German political text to classify
    text = gr.Textbox(label="Text")
    # Output: the predicted topic label returned by predict_topic
    predictions = gr.Label(label="Topic")
    # Trigger button for topic prediction
    greet_btn = gr.Button("Predict Topic")
    greet_btn.click(fn=predict_topic, inputs=[text], outputs=[predictions])

    # Example texts so users can try the classifier without typing their own.
    # cache_examples=True has Gradio precompute and store predict_topic's
    # output for each example.
    gr.Markdown("## Examples")
    gr.Examples(
        ["Sachgebiet Ausschließliche Gesetzgebungskompetenz des Bundes über die Zusammenarbeit des Bundes und der Länder zum Schutze der freiheitlichen demokratischen Grundordnung, des Bestandes und der Sicherheit des Bundes oder eines Landes.", "Sachgebiet Investive Ausgaben des Bundes Bundesfinanzminister Apel hat gemäß BMF Finanznachrichten vom . Januar erklärt , die Investitionsquote des Bundes sei in den letzten zehn Jahren nahezu konstant geblieben."],
        [text],
        [predictions],
        fn=predict_topic,
        cache_examples=True,
    )

    # TODO: classify the texts of an uploaded file. The sketch below is not
    # wired up yet (upload_file is presumably meant as its callback).
    # NOTE(review): pd.DataFrame(file_output) would wrap the gr.File component
    # object itself, not the uploaded file's contents — read the file instead.
    # file_output = gr.File()
    # upload_button = gr.UploadButton("Click to Upload a File", outputs=[file_output])
    # gr.Dataframe(pd.DataFrame(file_output))

    # Link to paper and Github repo
    gr.Markdown('''For more details: You can read our [paper](http://www.lrec-conf.org/proceedings/lrec2022/workshops/ParlaCLARINIII/pdf/2022.parlaclariniii-1.13.pdf) or access our [code](https://github.com/chkla/FrameASt).''')
    gr.Markdown('''Enjoy and stay tuned 🚀''')
def main():
    """Start the Gradio web server for the ``demo`` Blocks app."""
    app = demo
    app.launch()


if __name__ == "__main__":
    main()