vjawa_update_model_card

#4
by VibhuJawa - opened
Files changed (2)
  1. README.md +3 -127
  2. config.json +0 -1
README.md CHANGED
@@ -2,132 +2,8 @@
2
  tags:
3
  - pytorch_model_hub_mixin
4
  - model_hub_mixin
5
- license: apache-2.0
6
  ---
7
- # NemoCurator Domain Classifier
8
 
9
- # Model Overview
10
- This is a text classification model to classify documents into one of 26 domain classes:
11
- ```
12
- 'Adult', 'Arts_and_Entertainment', 'Autos_and_Vehicles', 'Beauty_and_Fitness', 'Books_and_Literature', 'Business_and_Industrial', 'Computers_and_Electronics', 'Finance', 'Food_and_Drink', 'Games', 'Health', 'Hobbies_and_Leisure', 'Home_and_Garden', 'Internet_and_Telecom', 'Jobs_and_Education', 'Law_and_Government', 'News', 'Online_Communities', 'People_and_Society', 'Pets_and_Animals', 'Real_Estate', 'Science', 'Sensitive_Subjects', 'Shopping', 'Sports', 'Travel_and_Transportation'
13
- ```
14
-
15
- # Model Architecture
16
- - The model architecture is Deberta V3 Base
17
- - Context length is 512 tokens
18
-
19
- # Training Details
20
- ## Training data:
21
- - 1 million Common Crawl samples, labeled using Google Cloud’s Natural Language API: https://cloud.google.com/natural-language/docs/classifying-text
22
- - 500k Wikepedia articles, curated using Wikipedia-API: https://pypi.org/project/Wikipedia-API/
23
-
24
- ## Training steps:
25
- Model was trained in multiple rounds using Wikipedia and Common Crawl data, labeled by a combination of pseudo labels and Google Cloud API.
26
-
27
- # How To Use This Model
28
- ## Input
29
- The model takes one or several paragraphs of text as input.
30
- Example input:
31
- ```
32
- q Directions
33
- 1. Mix 2 flours and baking powder together
34
- 2. Mix water and egg in a separate bowl. Add dry to wet little by little
35
- 3. Heat frying pan on medium
36
- 4. Pour batter into pan and then put blueberries on top before flipping
37
- 5. Top with desired toppings!
38
- ```
39
- ## Output
40
- The model outputs one of the 26 domain classes as the predicted domain for each input sample.
41
- Example output:
42
- ```
43
- Food_and_Drink
44
- ```
45
-
46
- # How to Use in NVIDIA NeMo Curator
47
-
48
- The inference code is available on [NeMo Curator's GitHub repository](https://github.com/NVIDIA/NeMo-Curator). Check out this [example notebook](https://github.com/NVIDIA/NeMo-Curator/blob/main/tutorials/distributed_data_classification/domain-classification.ipynb) to get started.
49
-
50
- # How to Use in Transformers
51
- To use the domain classifier, use the following code:
52
-
53
- ```python
54
- import torch
55
- from torch import nn
56
- from transformers import AutoModel, AutoTokenizer, AutoConfig
57
- from huggingface_hub import PyTorchModelHubMixin
58
-
59
- class CustomModel(nn.Module, PyTorchModelHubMixin):
60
-     def __init__(self, config):
61
-         super(CustomModel, self).__init__()
62
-         self.model = AutoModel.from_pretrained(config["base_model"])
63
-         self.dropout = nn.Dropout(config["fc_dropout"])
64
-         self.fc = nn.Linear(self.model.config.hidden_size, len(config["id2label"]))
65
-
66
-     def forward(self, input_ids, attention_mask):
67
-         features = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
68
-         dropped = self.dropout(features)
69
-         outputs = self.fc(dropped)
70
-         return torch.softmax(outputs[:, 0, :], dim=1)
71
-
72
- # Setup configuration and model
73
- config = AutoConfig.from_pretrained("nvidia/domain-classifier")
74
- tokenizer = AutoTokenizer.from_pretrained("nvidia/domain-classifier")
75
- model = CustomModel.from_pretrained("nvidia/domain-classifier")
76
- model.eval()
77
-
78
- # Prepare and process inputs
79
- text_samples = ["Sports is a popular domain", "Politics is a popular domain"]
80
- inputs = tokenizer(text_samples, return_tensors="pt", padding="longest", truncation=True)
81
- outputs = model(inputs["input_ids"], inputs["attention_mask"])
82
-
83
- # Predict and display results
84
- predicted_classes = torch.argmax(outputs, dim=1)
85
- predicted_domains = [config.id2label[class_idx.item()] for class_idx in predicted_classes.cpu().numpy()]
86
- print(predicted_domains)
87
- # ['Sports', 'News']
88
- ```
89
-
90
- # Evaluation Benchmarks
91
-
92
- Evaluation Metric: PR-AUC
93
-
94
- PR-AUC score on evaluation set with 105k samples - 0.9873
95
-
96
- PR-AUC score for each domain:
97
- | Domain | PR-AUC |
98
- |:-------------------------|:-------|
99
- | Adult | 0.999 |
100
- | Arts_and_Entertainment | 0.997 |
101
- | Autos_and_Vehicles | 0.997 |
102
- | Beauty_and_Fitness | 0.997 |
103
- | Books_and_Literature | 0.995 |
104
- | Business_and_Industrial | 0.982 |
105
- | Computers_and_Electronics| 0.992 |
106
- | Finance | 0.989 |
107
- | Food_and_Drink | 0.998 |
108
- | Games | 0.997 |
109
- | Health | 0.997 |
110
- | Hobbies_and_Leisure | 0.984 |
111
- | Home_and_Garden | 0.997 |
112
- | Internet_and_Telecom | 0.982 |
113
- | Jobs_and_Education | 0.993 |
114
- | Law_and_Government | 0.967 |
115
- | News | 0.918 |
116
- | Online_Communities | 0.983 |
117
- | People_and_Society | 0.975 |
118
- | Pets_and_Animals | 0.997 |
119
- | Real_Estate | 0.997 |
120
- | Science | 0.988 |
121
- | Sensitive_Subjects | 0.982 |
122
- | Shopping | 0.995 |
123
- | Sports | 0.995 |
124
- | Travel_and_Transportation| 0.996 |
125
- | Mean | 0.9873 |
126
-
127
- # References
128
- - https://arxiv.org/abs/2111.09543
129
- - https://github.com/microsoft/DeBERTa
130
-
131
- # License
132
- License to use this model is covered by the Apache 2.0. By downloading the public and release version of the model, you accept the terms and conditions of the Apache License 2.0.
133
- This repository contains the code for the domain classifier model.
 
2
  tags:
3
  - pytorch_model_hub_mixin
4
  - model_hub_mixin
 
5
  ---
 
6
 
7
+ This model has been pushed to the Hub using ****:
8
+ - Repo: [More Information Needed]
9
+ - Docs: [More Information Needed]
 
config.json CHANGED
@@ -1,6 +1,5 @@
1
  {
2
  "base_model": "microsoft/deberta-v3-base",
3
- "model_type": "deberta-v2",
4
  "config_path": null,
5
  "fc_dropout": 0.2,
6
  "id2label": {
 
1
  {
2
  "base_model": "microsoft/deberta-v3-base",
 
3
  "config_path": null,
4
  "fc_dropout": 0.2,
5
  "id2label": {
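
For reference, the label lookup in the removed README snippet goes through `id2label`. The following is a minimal sketch of that argmax-to-label step in plain Python, assuming `config.json` is read as raw JSON, where object keys are strings rather than integers. The three sample labels and the `decode` helper are hypothetical and only illustrate the shape of the mapping; the real `id2label` contents are not shown in this diff.

```python
import json

# Hypothetical excerpt of config.json; the actual id2label entries are elided in the diff.
raw = (
    '{"base_model": "microsoft/deberta-v3-base", "fc_dropout": 0.2, '
    '"id2label": {"0": "Adult", "1": "Arts_and_Entertainment", "2": "Autos_and_Vehicles"}}'
)
config = json.loads(raw)


def decode(probabilities, id2label):
    """Return the highest-probability label for each row of class probabilities."""
    labels = []
    for row in probabilities:
        best = max(range(len(row)), key=row.__getitem__)  # index of the largest value
        labels.append(id2label[str(best)])  # raw JSON object keys are strings
    return labels


print(decode([[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]], config["id2label"]))
```

Reading the file this way does not rely on a `model_type` entry, which is the key this change removes from `config.json`.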