
Model Card for robeczech-2stage-supportive-interactions-cs

This model is fine-tuned for the second stage of a two-stage, multi-label text classification of supportive interactions in instant-messenger dialogs of adolescents. It expects inputs in which at least one of the classes is present.
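Because this checkpoint is only the second stage, it should only see context windows that already contain some supportive interaction. A minimal sketch of that gating logic is shown below. It assumes a hypothetical first-stage scorer `stage1_score` that returns the probability that any supportive interaction is present; that scorer is not part of this model card. The helper `two_stage_classify` and its threshold are likewise illustrative.

```python
from typing import Callable, List, Optional, Sequence


def two_stage_classify(
    texts: Sequence[str],
    stage1_score: Callable[[str], float],        # hypothetical: P(any support present)
    stage2_predict: Callable[[str], List[int]],  # e.g. a wrapper around this model
    stage1_threshold: float = 0.5,
) -> List[Optional[List[int]]]:
    """Run the second-stage classifier only on texts the first stage accepts.

    Returns None for texts rejected by stage 1, otherwise the stage-2 label vector.
    """
    results: List[Optional[List[int]]] = []
    for text in texts:
        if stage1_score(text) >= stage1_threshold:
            results.append(stage2_predict(text))
        else:
            results.append(None)  # stage 1 found no supportive interaction
    return results


# Usage with dummy stand-ins for both stages:
demo = two_stage_classify(
    ["a;b", "c;d"],
    stage1_score=lambda t: 0.9 if t.startswith("a") else 0.1,
    stage2_predict=lambda t: [0, 1, 0, 0, 0],
)
print(demo)
```

Keeping the two stages as separate callables means the second-stage model is never asked to label windows it was not trained on.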

Model Description

The model was fine-tuned on a dataset of instant-messenger dialogs of adolescents. The classification is two-stage, and this model outputs one probability per label for label indices 0 to 4, corresponding to:

  1. Informational Support
  2. Emotional Support
  3. Social Companionship
  4. Appraisal
  5. Instrumental Support
  • Developed by: Anonymous
  • Language(s): cs
  • Finetuned from: ufal/robeczech
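When post-processing the output, the index-to-label mapping can be made explicit. The sketch below assumes the index order matches the list above (index 0 = Informational Support, through index 4 = Instrumental Support) and the 0.5 threshold used in the usage example; it is worth checking the checkpoint's `config.id2label` before relying on the names. `LABELS` and `decode_labels` are illustrative names, not part of the model's API.

```python
from typing import Dict, List, Sequence

# Assumed index order, taken from the label list in this card.
LABELS: List[str] = [
    "Informational Support",
    "Emotional Support",
    "Social Companionship",
    "Appraisal",
    "Instrumental Support",
]


def decode_labels(probs: Sequence[float], threshold: float = 0.5) -> Dict[str, float]:
    """Return {label name: probability} for every label at or above threshold.

    Multi-label: several labels can pass the threshold for the same input.
    """
    if len(probs) != len(LABELS):
        raise ValueError(f"expected {len(LABELS)} probabilities, got {len(probs)}")
    return {name: float(p) for name, p in zip(LABELS, probs) if p >= threshold}


# Made-up probabilities, e.g. probs.tolist() from predict_one:
print(decode_labels([0.91, 0.62, 0.05, 0.30, 0.02]))
```

Since each label has its own sigmoid, the probabilities do not sum to 1, and any subset of labels (including none) can pass the threshold.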

Model Sources

Usage

Here is how to use this model to classify a context window of a dialogue:

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

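# Prepare input texts: each input is one context window of a dialogue,
# with utterances joined by ';'. The base model, RobeCzech, is a
# monolingual Czech RoBERTa, and this checkpoint targets Czech (cs) input.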
test_texts = ['Utterance1;Utterance2;Utterance3']

# Load the model and tokenizer
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU
model = AutoModelForSequenceClassification.from_pretrained(
    'justtherightsize/robeczech-2stage-supportive-interactions-cs', num_labels=5).to(device)

tokenizer = AutoTokenizer.from_pretrained(
    'justtherightsize/robeczech-2stage-supportive-interactions-cs',
    use_fast=False, truncation_side='left')
assert tokenizer.truncation_side == 'left'

# Define helper functions
def predict_one(text: str, tok, mod, threshold=0.5):
    """Return (binary label vector, per-label probabilities) for one input."""
    encoding = tok(text, return_tensors="pt", truncation=True, padding=True,
                   max_length=256)
    encoding = {k: v.to(mod.device) for k, v in encoding.items()}
    with torch.no_grad():  # inference only, no gradients needed
        logits = mod(**encoding).logits
    # Multi-label head: an independent sigmoid per label, not a softmax
    probs = torch.sigmoid(logits).squeeze(0).cpu().numpy()
    predictions = (probs >= threshold).astype(np.int64)
    return predictions, probs

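def print_predictions(texts):
    preds = [predict_one(tt, tokenizer, model) for tt in texts]
    for text, (labels, probs) in zip(texts, preds):
        # Format each probability separately; a list has no ':.4f' format
        probs_str = ', '.join(f'{p:.4f}' for p in probs.tolist())
        print(f'{text}\n  labels: {labels.tolist()}\n  probs:  [{probs_str}]')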

# Run the prediction
print_predictions(test_texts)
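The usage example passes a single context window as utterances joined with `;` (the `'Utterance1;Utterance2;Utterance3'` placeholder). The following is a minimal sketch for producing such windows from a longer dialogue. It assumes a sliding window over the `size` most recent utterances ending at each turn; `context_windows` and its default window size are hypothetical and not taken from the training setup.

```python
from typing import List, Sequence


def context_windows(utterances: Sequence[str], size: int = 3, sep: str = ";") -> List[str]:
    """Build one sep-joined window per turn from the `size` most recent utterances.

    Window i ends at utterance i, so the first turns get shorter windows.
    """
    if size < 1:
        raise ValueError("size must be >= 1")
    windows: List[str] = []
    for i in range(len(utterances)):
        start = max(0, i - size + 1)  # oldest utterance kept in this window
        windows.append(sep.join(utterances[start:i + 1]))
    return windows


print(context_windows(["Ahoj", "Jak se máš?", "Dobře, díky"], size=2))
```

The resulting list can be passed straight to `print_predictions`. Because the tokenizer is loaded with `truncation_side='left'`, any window longer than 256 tokens loses its oldest utterances first, so the most recent turn, placed last, is preserved.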