---
inference: false
license: mit
tags:
- Zero-Shot Classification
language:
- multilingual
- af
- am
- ar
- as
- az
- be
- bg
- bn
- br
- bs
- ca
- cs
- cy
- da
- de
- el
- en
- eo
- es
- et
- eu
- fa
- fi
- fr
- fy
- ga
- gd
- gl
- gu
- ha
- he
- hi
- hr
- hu
- hy
- id
- is
- it
- ja
- jv
- ka
- kk
- km
- kn
- ko
- ku
- ky
- la
- lo
- lt
- lv
- mg
- mk
- ml
- mn
- mr
- ms
- my
- ne
- nl
- 'no'
- om
- or
- pa
- pl
- ps
- pt
- ro
- ru
- sa
- sd
- si
- sk
- sl
- so
- sq
- sr
- su
- sv
- sw
- ta
- te
- th
- tl
- tr
- ug
- uk
- ur
- uz
- vi
- xh
- yi
- zh
pipeline_tag: zero-shot-classification
metrics:
- accuracy
---

# Zero-shot text classification (multilingual version) trained with self-supervised tuning

This is a zero-shot text classification model trained with self-supervised tuning (SSTuning). It was introduced in the paper [Zero-Shot Text Classification via Self-Supervised Tuning](https://arxiv.org/abs/2305.11442) by Chaoqun Liu, Wenxuan Zhang, Guizhen Chen, Xiaobao Wu, Anh Tuan Luu, Chip Hong Chang, and Lidong Bing, and first released in [this repository](https://github.com/DAMO-NLP-SG/SSTuning). The model backbone is xlm-roberta-base.

## Model description

The model is tuned on unlabeled data with a first sentence prediction (FSP) learning objective. The FSP task is designed to fit both the nature of the unlabeled corpus and the input/output format of classification tasks, and the training and validation sets are constructed from the unlabeled corpus with FSP. During tuning, a BERT-like pre-trained masked language model (RoBERTa, ALBERT, or, for this model, XLM-RoBERTa) serves as the backbone, and a classification output layer is added on top. The FSP objective is to predict the index of the correct label, and the model is tuned with a cross-entropy loss.

## Model variations

Four versions of the model have been released.
Their details are as follows:

| Model | Backbone | #Params | Language | Accuracy | Speed | #Training samples |
|-------|----------|---------|----------|----------|-------|-------------------|
| [zero-shot-classify-SSTuning-base](https://huggingface.co/DAMO-NLP-SG/zero-shot-classify-SSTuning-base) | [roberta-base](https://huggingface.co/roberta-base) | 125M | En | Low | High | 20.48M |
| [zero-shot-classify-SSTuning-large](https://huggingface.co/DAMO-NLP-SG/zero-shot-classify-SSTuning-large) | [roberta-large](https://huggingface.co/roberta-large) | 355M | En | Medium | Medium | 5.12M |
| [zero-shot-classify-SSTuning-ALBERT](https://huggingface.co/DAMO-NLP-SG/zero-shot-classify-SSTuning-ALBERT) | [albert-xxlarge-v2](https://huggingface.co/albert-xxlarge-v2) | 235M | En | High | Low | 5.12M |
| [zero-shot-classify-SSTuning-XLM-R](https://huggingface.co/DAMO-NLP-SG/zero-shot-classify-SSTuning-XLM-R) | [xlm-roberta-base](https://huggingface.co/xlm-roberta-base) | 278M | Multi | - | - | 20.48M |

Note that zero-shot-classify-SSTuning-XLM-R is trained on 20.48M English samples only. However, it can also be used for any other language that xlm-roberta supports. See [this repository](https://github.com/DAMO-NLP-SG/SSTuning) for the performance of each model.

## Intended uses & limitations

The model can be used for zero-shot text classification tasks such as sentiment analysis and topic classification; no further fine-tuning is needed. The number of labels should be between 2 and 20, because the input always contains 20 option slots and unused slots are filled with padding tokens (see the example below).

### How to use

You can try the model in this Colab [notebook](https://colab.research.google.com/drive/17bqc8cXFF-wDmZ0o8j7sbrQB9Cq7Gowr?usp=sharing).

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import random
import string

import torch

# Load the tokenizer and the SSTuning model (xlm-roberta-base backbone)
tokenizer = AutoTokenizer.from_pretrained("DAMO-NLP-SG/zero-shot-classify-SSTuning-XLM-R")
model = AutoModelForSequenceClassification.from_pretrained("DAMO-NLP-SG/zero-shot-classify-SSTuning-XLM-R")

text = "I love this place! The food is always so fresh and delicious."
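# The backbone is multilingual, so `text` may also be written in another
# language supported by xlm-roberta, even though the model was tuned on
# English data only. An illustrative (hypothetical) German input is shown
# below; its prediction has not been verified here.
# text = "Ich liebe diesen Ort! Das Essen ist immer so frisch und lecker."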
list_label = ["negative", "positive"]

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
list_ABC = list(string.ascii_uppercase)  # option letters: A, B, C, ...

def check_text(model, text, list_label, shuffle=False):
    # The model scores a fixed set of 20 options, so 2 to 20 labels are supported
    assert 2 <= len(list_label) <= 20, "the number of labels should be between 2 and 20"
    # Make sure every label ends with a period
    list_label = [x + '.' if x[-1] != '.' else x for x in list_label]
    # Pad the option list to 20 slots with the tokenizer's padding token
    list_label_new = list_label + [tokenizer.pad_token] * (20 - len(list_label))
    if shuffle:
        random.shuffle(list_label_new)
    # Build the input "(A) label1. (B) label2. ... <sep> text"
    s_option = ' '.join(['(' + list_ABC[i] + ') ' + list_label_new[i] for i in range(len(list_label_new))])
    text = f'{s_option} {tokenizer.sep_token} {text}'

    model.to(device).eval()
    encoding = tokenizer([text], truncation=True, max_length=512, return_tensors='pt')
    item = {key: val.to(device) for key, val in encoding.items()}
    with torch.no_grad():  # inference only, no gradients needed
        logits = model(**item).logits

    # Without shuffling, the real labels occupy the first len(list_label) slots,
    # so the padding options are dropped; with shuffling, all 20 slots are scored
    logits = logits if shuffle else logits[:, 0:len(list_label)]
    probs = torch.nn.functional.softmax(logits, dim=-1).tolist()
    predictions = torch.argmax(logits, dim=-1).item()
    probabilities = [round(x, 5) for x in probs[0]]

    print(f'prediction: {predictions} => ({list_ABC[predictions]}) {list_label_new[predictions]}')
    print(f'probability: {round(probabilities[predictions] * 100, 2)}%')

check_text(model, text, list_label)
# prediction: 1 => (B) positive.
# probability: 99.92%
```

### BibTeX entry and citation info

```bibtex
@inproceedings{acl23/SSTuning,
  author    = {Chaoqun Liu and Wenxuan Zhang and Guizhen Chen and Xiaobao Wu and Anh Tuan Luu and Chip Hong Chang and Lidong Bing},
  title     = {Zero-Shot Text Classification via Self-Supervised Tuning},
  booktitle = {Findings of the Association for Computational Linguistics: ACL 2023},
  year      = {2023},
  url       = {https://arxiv.org/abs/2305.11442},
}
```