ryantwolf committed on
Commit 28f3e16 · verified · 1 Parent(s): 1c2ec51

Update README.md

Files changed (1):
  1. README.md +155 -4
README.md CHANGED
@@ -1,9 +1,160 @@
 ---
 tags:
-- model_hub_mixin
 - pytorch_model_hub_mixin
+- model_hub_mixin
+license: apache-2.0
 ---
-
-This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
-- Library: [More Information Needed]
-- Docs: [More Information Needed]

# Quality Classifier DeBERTa

# Model Overview
This is a text classification model that enables qualitative data annotation, the creation of quality-specific data blends, and the addition of metadata tags. The model classifies documents into one of three classes based on the quality of the document: "High", "Medium", or "Low".

The model was trained on data labeled by human annotators, who considered quality factors such as content accuracy, clarity, coherence, grammar, depth of information, and the overall usefulness of the document.

This model is used in [NVIDIA NeMo Curator](https://github.com/NVIDIA/NeMo-Curator) as part of its qualitative filtering module.
# Model Architecture
- The model architecture is DeBERTa V3 Base.
- The context length is 512 tokens.
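
Because the context length is 512 tokens, anything beyond that is truncated before classification. A minimal sketch of how truncation plays out with the tokenizer (the model ID is taken from the usage section below; the truncation arguments are standard `transformers` tokenizer options, not something specific to this model):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("nvidia/quality-classifier-deberta")

long_text = "word " * 5000  # deliberately longer than the context window
ids = tokenizer(long_text, truncation=True, max_length=512)["input_ids"]
print(len(ids))  # 512 -- tokens past the context length are discarded
```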
# Training Details
## Training data
- 1 million Common Crawl samples, labeled using Google Cloud’s Natural Language API: https://cloud.google.com/natural-language/docs/classifying-text
- 500k Wikipedia articles, curated using the Wikipedia-API: https://pypi.org/project/Wikipedia-API/
## Training steps
The training set is 22,828 Common Crawl text samples labeled as "High", "Medium", or "Low". Here are some examples:
1. Input:
   ```
   Volunteering

   It's all about the warm, fuzzy feeling when you serve the community, without expectation of gain. Volunteering offers you the necessary experience and development skills to take forward with you, as you venture out to work with other people and apply what you learn, to achieve your career goals.

   HOW IT WORKS

   SEARCH

   BOOK NOW

   ENJOY THE SHOW

   GET A FREE QUOTE

   Planning your event ahead of time is the right move. Contact our experts and let us surprise you.
   ```
   Output: `Low`
2. Input:
   ```
   Sharapova has been in New Zealand since well before the New Year, preparing for her 2011 start and requested the opening day match to test her form. "My last tournament was over two months ago and it will be really good to get back playing again."

   "My priority since I have been here has been to adjust to time and conditions. I have had a couple of practices a day and think that has been really important."

   The three-time Grand Slam champion who once stood number one next plays Voracova after winning their only previous match in 2003.
   ```
   Output: `High`

# How To Use This Model
## Input
The model takes one or several paragraphs of text as input.

Example input:
```
Reasons to visit Thatta

Thatta is one of the most important cities of the province of Sindh, Pakistan. Historically it is the richest city. The sands of Thatta have seen many great men. It provided Alexander the Great and his troops a comfortable resting place before they moved further. It welcomed the Mughal Emperor Shah Jehan.
```
## Output
The model outputs one of the three classes as the predicted quality for each input sample.

Example output:
```
Medium
```

# How to Use in NeMo Curator

The inference code is available in [NeMo Curator's GitHub repository](https://github.com/NVIDIA/NeMo-Curator). Download the [model.pth](https://huggingface.co/nvidia/quality-classifier-deberta/blob/main/model.pth) and check out this [example notebook](https://github.com/NVIDIA/NeMo-Curator/blob/main/tutorials/distributed_data_classification/distributed_data_classification.ipynb) to get started.

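For orientation, a rough sketch of what classifier-based filtering looks like in NeMo Curator is below. The import paths, class name, and constructor arguments are assumptions based on the NeMo Curator documentation at the time of writing; the linked notebook is the authoritative reference.

```python
# Hypothetical sketch -- verify names and arguments against the NeMo Curator docs.
from nemo_curator.classifiers import QualityClassifier  # assumed import path
from nemo_curator.datasets import DocumentDataset

# Read a JSONL dataset (GPU-backed, since the classifier runs on GPU).
dataset = DocumentDataset.read_json("input_data/", backend="cudf")

classifier = QualityClassifier(batch_size=256)  # assumed constructor argument
result = classifier(dataset)                    # adds a quality label column

result.to_json("output_data/")
```
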
# How to Use in Transformers
To use the quality classifier, run the following code:

```python
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer, AutoConfig
from huggingface_hub import PyTorchModelHubMixin


class QualityModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config):
        super(QualityModel, self).__init__()
        self.model = AutoModel.from_pretrained(config["base_model"])
        self.dropout = nn.Dropout(config["fc_dropout"])
        self.fc = nn.Linear(self.model.config.hidden_size, len(config["id2label"]))

    def forward(self, input_ids, attention_mask):
        features = self.model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        dropped = self.dropout(features)
        outputs = self.fc(dropped)
        # Classify from the first ([CLS]) token and return class probabilities.
        return torch.softmax(outputs[:, 0, :], dim=1)


device = "cuda" if torch.cuda.is_available() else "cpu"

# Set up the configuration, tokenizer, and model
config = AutoConfig.from_pretrained("nvidia/quality-classifier-deberta")
tokenizer = AutoTokenizer.from_pretrained("nvidia/quality-classifier-deberta")
model = QualityModel.from_pretrained("nvidia/quality-classifier-deberta").to(device)
model.eval()

# Prepare and process inputs
text_samples = [".?@fdsa Low quality text.", "This sentence is ok."]
inputs = tokenizer(
    text_samples, return_tensors="pt", padding="longest", truncation=True
).to(device)
outputs = model(inputs["input_ids"], inputs["attention_mask"])

# Predict and display results
predicted_classes = torch.argmax(outputs, dim=1)
predicted_labels = [
    config.id2label[class_idx.item()] for class_idx in predicted_classes.cpu().numpy()
]
print(predicted_labels)
# ['Low', 'Medium']
```
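
Note that this snippet computes gradients it never uses; for bulk inference it is standard PyTorch practice (not a requirement of this model) to wrap the forward pass in `torch.no_grad()` to save memory.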

# Evaluation Benchmarks
## Evaluation data

The evaluation data is a subset of the training data on which all three annotators agreed on the label. It contains 7,128 samples.

## Results
Accuracy on the evaluation set of 7,128 samples: `0.8252`

| Class  | Precision | Recall | F1-Score |
|--------|-----------|--------|----------|
| High   | 0.5043    | 0.1776 | 0.2626   |
| Medium | 0.8325    | 0.9396 | 0.8825   |
| Low    | 0.8510    | 0.7279 | 0.7842   |

Confusion Matrix:

We verify that the predicted classes are, for the most part, close to their ground-truth labels, and that the disagreements are consistent with the noisy nature of the annotation. Rows are ground-truth labels and columns are predictions.

| Actual \ Predicted | High | Medium | Low  |
|--------------------|------|--------|------|
| High               | 117  | 541    | 1    |
| Medium             | 115  | 4688   | 187  |
| Low                | 0    | 402    | 1077 |

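As a sanity check, the reported metrics can be reproduced from the confusion matrix. A short worked example (pure Python, no dependencies):

```python
# Confusion matrix from the table above: rows = actual, columns = predicted.
cm = [
    [117, 541, 1],     # actual High
    [115, 4688, 187],  # actual Medium
    [0, 402, 1077],    # actual Low
]

total = sum(sum(row) for row in cm)                  # 7128 samples
accuracy = (cm[0][0] + cm[1][1] + cm[2][2]) / total
print(round(accuracy, 4))                            # 0.8252, as reported

# Precision and recall for the "High" class (index 0).
precision_high = cm[0][0] / (cm[0][0] + cm[1][0] + cm[2][0])  # 117 / 232
recall_high = cm[0][0] / sum(cm[0])                           # 117 / 659
print(round(precision_high, 4), round(recall_high, 4))  # 0.5043 0.1775 (~0.1776 in the table)
```
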
# Limitations
- Subjectivity in Quality: Quality assessment is inherently subjective and may vary among different annotators.

# References
- DeBERTaV3 paper: https://arxiv.org/abs/2111.09543
- DeBERTa repository: https://github.com/microsoft/DeBERTa

# License
Use of this model is covered by the Apache License 2.0. By downloading the public release version of the model, you accept the terms and conditions of the Apache License 2.0.

This repository contains the code for the quality classifier model.