You can easily load our continually post-trained model with Hugging Face's `transformers`:

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load our model; the weights are downloaded automatically.
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = AutoModelForSequenceClassification.from_pretrained("UIC-Liu-Lab/CPT", trust_remote_code=True)

# Tokenize input texts
texts = [
    "There's a kid on a skateboard.",
    "A kid is skateboarding.",
    "A kid is inside the house."
]
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

# Task id and smax
t = torch.LongTensor([0]).to(model.device)  # use task 0's CL-plugin; choose from {0, 1, 2, 3}
smax = 400

# Get the model output
res = model(**inputs, return_dict=True, t=t, s=smax)
```

If you encounter any problems loading the model directly through the Hugging Face API, you can also download the weights manually from the [repo](https://huggingface.co/UIC-Liu-Lab/CPT/tree/main) and use `model = AutoModel.from_pretrained({PATH TO THE DOWNLOADED MODEL})`.
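
For example, here is a minimal sketch of the manual route, assuming `huggingface_hub` is installed (the local path comes from `snapshot_download`; cloning the repo with git-lfs works just as well):

```python
from huggingface_hub import snapshot_download
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Fetch the full repo contents to a local directory.
local_dir = snapshot_download(repo_id="UIC-Liu-Lab/CPT")

# Load from the local path instead of the Hub id.
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = AutoModelForSequenceClassification.from_pretrained(local_dir, trust_remote_code=True)
```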

Note: the post-trained weights you load contain untrained classification heads. The post-training sequence is `Restaurant -> AI -> ACL -> AGNews`; you can use the downloaded weights to fine-tune the corresponding end task (a minimal fine-tuning sketch follows the table below). The results (MF1 / Acc) should be consistent with the following table.

| Model           | Restaurant    | AI            | ACL           | AGNews        | Avg.          |
| --------------- | ------------- | ------------- | ------------- | ------------- | ------------- |
| UIC-Liu-Lab/CPT | 53.90 / 75.13 | 30.42 / 30.89 | 37.56 / 38.53 | 63.77 / 65.79 | 46.41 / 52.59 |
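
The sketch below shows one way such fine-tuning could look. It is illustrative, not the authors' recipe: `train_batches`, the label handling, the learning rate, and the assumption that the model output exposes `logits` are ours.

```python
import torch
import torch.nn.functional as F
from torch.optim import AdamW

# Pick the CL-plugin matching the end task's position in the post-training
# sequence Restaurant -> AI -> ACL -> AGNews (here: 0 = Restaurant).
t = torch.LongTensor([0]).to(model.device)
smax = 400

optimizer = AdamW(model.parameters(), lr=2e-5)  # hyperparameters are illustrative
model.train()

# `train_batches` is a hypothetical iterable of (texts, labels) pairs for the end task.
for texts, labels in train_batches:
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(model.device)
    out = model(**inputs, return_dict=True, t=t, s=smax)
    loss = F.cross_entropy(out.logits, labels.to(model.device))  # assumes the output exposes `logits`
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```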