|
"""Gradio app that showcases Scandinavian zero-shot text classification models.""" |
|
|
|
import gradio as gr |
|
from transformers import pipeline |
|
from luga import language as detect_language |
|
|
|
|
|
|
|
# Zero-shot NLI classifier shared by every request; it is constructed once,
# when this module is imported, and reused by ``sentiment_classification``.
classifier = pipeline(
    task="zero-shot-classification",
    model="alexandrainst/scandi-nli-large",
)
|
|
|
|
|
# Per-language NLI prompt: (hypothesis template, candidate labels).
# Labels are tuples so the shared table can never be mutated by a caller.
_NORWEGIAN_PROMPT = ("Dette eksemplet er {}.", ("positivt", "negativt", "nøytralt"))
_SENTIMENT_PROMPTS = {
    "da": ("Dette eksempel er {}.", ("positivt", "negativt", "neutralt")),
    "sv": ("Detta exempel är {}.", ("positivt", "negativt", "neutralt")),
    "no": _NORWEGIAN_PROMPT,
    # NOTE(review): "nb"/"nn" are included on the assumption that the language
    # detector may report Bokmål/Nynorsk codes instead of "no" — confirm.
    "nb": _NORWEGIAN_PROMPT,
    "nn": _NORWEGIAN_PROMPT,
}

# Prompt language used when the detected language has no entry above.
# Previously such input crashed with UnboundLocalError.
_DEFAULT_PROMPT_LANGUAGE = "da"


def sentiment_classification(doc: str) -> str:
    """Classify text into sentiment categories.

    The language of ``doc`` is detected first, and a hypothesis template and
    label set in that language are fed to the zero-shot classifier. If the
    detected language is not one of the supported Scandinavian languages, the
    Danish prompt is used instead of failing.

    Args:
        doc (str):
            Text to classify.

    Returns:
        str:
            The highest-scoring sentiment label, in the prompt language
            (e.g. "positivt", "negativt", "neutralt"/"nøytralt").
    """
    language = detect_language(doc).name

    hypothesis_template, labels = _SENTIMENT_PROMPTS.get(
        language, _SENTIMENT_PROMPTS[_DEFAULT_PROMPT_LANGUAGE]
    )

    result = classifier(
        doc,
        candidate_labels=list(labels),  # fresh list per call; table stays immutable
        hypothesis_template=hypothesis_template,
    )

    # The pipeline returns labels sorted by descending score.
    return result["labels"][0]
|
|
|
|
|
|
|
# Web UI: a multi-line text box in, the predicted sentiment label out.
# ``gr.Textbox``/``gr.Label`` replace the ``gr.inputs``/``gr.outputs``
# namespaces, which were deprecated in Gradio 3 and removed in Gradio 4.
interface = gr.Interface(
    fn=sentiment_classification,
    inputs=gr.Textbox(lines=5, label="Text"),
    outputs=gr.Label(),
    title="Scandinavian Zero-Shot Text Classification",
    description="Classify text into sentiment categories.",
)


# Only start the server when run as a script, so importing this module
# (e.g. from tests or tooling) does not block on ``launch()``.
if __name__ == "__main__":
    interface.launch()
|
|