Files changed (2) hide show
  1. README.md +3 -121
  2. config.json +0 -1
README.md CHANGED
@@ -2,126 +2,8 @@
2
  tags:
3
  - pytorch_model_hub_mixin
4
  - model_hub_mixin
5
- license: apache-2.0
6
  ---
7
- # Domain Classifier
8
 
9
- # Model Overview
10
- This is a text classification model to classify documents into one of 26 domain classes:
11
-
12
- 'Adult', 'Arts_and_Entertainment', 'Autos_and_Vehicles', 'Beauty_and_Fitness', 'Books_and_Literature', 'Business_and_Industrial', 'Computers_and_Electronics', 'Finance', 'Food_and_Drink', 'Games', 'Health', 'Hobbies_and_Leisure', 'Home_and_Garden', 'Internet_and_Telecom', 'Jobs_and_Education', 'Law_and_Government', 'News', 'Online_Communities', 'People_and_Society', 'Pets_and_Animals', 'Real_Estate', 'Science', 'Sensitive_Subjects', 'Shopping', 'Sports', 'Travel_and_Transportation'
13
- # Model Architecture
14
- The model architecture is Deberta V3 Base
15
- Context length is 512 tokens
16
- # Training (details)
17
- ## Training data:
18
- - 1 million Common Crawl samples, labeled using Google Cloud’s Natural Language API: https://cloud.google.com/natural-language/docs/classifying-text
19
- - 500k Wikepedia articles, curated using Wikipedia-API: https://pypi.org/project/Wikipedia-API/
20
- ## Training steps:
21
- Model was trained in multiple rounds using Wikipedia and Common Crawl data, labeled by a combination of pseudo labels and Google Cloud API.
22
- # How To Use This Model
23
- ## Input
24
- The model takes one or several paragraphs of text as input.
25
- Example input:
26
- ```
27
- q Directions
28
- 1. Mix 2 flours and baking powder together
29
- 2. Mix water and egg in a separate bowl. Add dry to wet little by little
30
- 3. Heat frying pan on medium
31
- 4. Pour batter into pan and then put blueberries on top before flipping
32
- 5. Top with desired toppings!
33
- ```
34
- ## Output
35
- The model outputs one of the 26 domain classes as the predicted domain for each input sample.
36
- Example output:
37
- ```
38
- Food_and_Drink
39
- ```
40
-
41
- # How to use in NeMo Curator
42
-
43
- The inference code is available on [NeMo Curator's GitHub repository](https://github.com/NVIDIA/NeMo-Curator). Download the [model.pth](https://huggingface.co/nvidia/domain-classifier/blob/main/model.pth) file and check out this [example notebook](https://github.com/NVIDIA/NeMo-Curator/blob/main/tutorials/distributed_data_classification/distributed_data_classification.ipynb) to get started.
44
-
45
- # How to use in transformers
46
- To use the Domain classifier, use the following code:
47
-
48
- ```python
49
-
50
- import torch
51
- from torch import nn
52
- from transformers import AutoModel, AutoTokenizer, AutoConfig
53
- from huggingface_hub import PyTorchModelHubMixin
54
-
55
- class CustomModel(nn.Module, PyTorchModelHubMixin):
56
- def __init__(self, config):
57
- super(CustomModel, self).__init__()
58
- self.model = AutoModel.from_pretrained(config['base_model'])
59
- self.dropout = nn.Dropout(config['fc_dropout'])
60
- self.fc = nn.Linear(self.model.config.hidden_size, len(config['id2label']))
61
-
62
- def forward(self, input_ids, attention_mask):
63
- features = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
64
- dropped = self.dropout(features)
65
- outputs = self.fc(dropped)
66
- return torch.softmax(outputs[:, 0, :], dim=1)
67
-
68
- # Setup configuration and model
69
- config = AutoConfig.from_pretrained("nvidia/domain-classifier")
70
- tokenizer = AutoTokenizer.from_pretrained("nvidia/domain-classifier")
71
- model = CustomModel.from_pretrained("nvidia/domain-classifier")
72
-
73
- # Prepare and process inputs
74
- text_samples = ["Sports is a popular domain", "Politics is a popular domain"]
75
- inputs = tokenizer(text_samples, return_tensors="pt", padding="longest", truncation=True)
76
- outputs = model(inputs['input_ids'], inputs['attention_mask'])
77
-
78
- # Predict and display results
79
- predicted_classes = torch.argmax(outputs, dim=1)
80
- predicted_domains = [config.id2label[class_idx.item()] for class_idx in predicted_classes.cpu().numpy()]
81
- print(predicted_domains)
82
- # ['Sports', 'News']
83
- ```
84
-
85
- # Evaluation Benchmarks
86
-
87
- Evaluation Metric: PR-AUC
88
-
89
- PR-AUC score on evaluation set with 105k samples - 0.9873
90
-
91
- PR-AUC score for each domain:
92
- | Domain | PR-AUC |
93
- |:-------------------------|:-------|
94
- | Adult | 0.999 |
95
- | Arts_and_Entertainment | 0.997 |
96
- | Autos_and_Vehicles | 0.997 |
97
- | Beauty_and_Fitness | 0.997 |
98
- | Books_and_Literature | 0.995 |
99
- | Business_and_Industrial | 0.982 |
100
- | Computers_and_Electronics| 0.992 |
101
- | Finance | 0.989 |
102
- | Food_and_Drink | 0.998 |
103
- | Games | 0.997 |
104
- | Health | 0.997 |
105
- | Hobbies_and_Leisure | 0.984 |
106
- | Home_and_Garden | 0.997 |
107
- | Internet_and_Telecom | 0.982 |
108
- | Jobs_and_Education | 0.993 |
109
- | Law_and_Government | 0.967 |
110
- | News | 0.918 |
111
- | Online_Communities | 0.983 |
112
- | People_and_Society | 0.975 |
113
- | Pets_and_Animals | 0.997 |
114
- | Real_Estate | 0.997 |
115
- | Science | 0.988 |
116
- | Sensitive_Subjects | 0.982 |
117
- | Shopping | 0.995 |
118
- | Sports | 0.995 |
119
- | Travel_and_Transportation| 0.996 |
120
- | Mean | 0.9873 |
121
-
122
- # References
123
- - https://arxiv.org/abs/2111.09543
124
- - https://github.com/microsoft/DeBERTa
125
- # License
126
- License to use this model is covered by the Apache 2.0. By downloading the public and release version of the model, you accept the terms and conditions of the Apache License 2.0.
127
- This repository contains the code for the domain classifier model.
 
2
  tags:
3
  - pytorch_model_hub_mixin
4
  - model_hub_mixin
 
5
  ---
 
6
 
7
+ This model has been pushed to the Hub using ****:
8
+ - Repo: [More Information Needed]
9
+ - Docs: [More Information Needed]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config.json CHANGED
@@ -1,6 +1,5 @@
1
  {
2
  "base_model": "microsoft/deberta-v3-base",
3
- "model_type": "deberta-v2",
4
  "config_path": null,
5
  "fc_dropout": 0.2,
6
  "id2label": {
 
1
  {
2
  "base_model": "microsoft/deberta-v3-base",
 
3
  "config_path": null,
4
  "fc_dropout": 0.2,
5
  "id2label": {