cbdb committed
Commit 51fd61b · verified · 1 parent: 87a461c

Fix colab link

Files changed (1): README.md (+53 −41)
README.md CHANGED
@@ -12,7 +12,7 @@ license: cc-by-nc-sa-4.0
  ---
 
  # <font color="IndianRed"> OTAS (Office Title Address Splitter)</font>
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1UoG3QebyBlK6diiYckiQv-5dRB9dA4iv?usp=sharing/)
 
  Our model <font color="cornflowerblue">OTAS (Office Title Address Splitter)</font> is a Named Entity Recognition model for Classical Chinese intended to <font color="IndianRed">split the address portion of Classical Chinese office titles</font>. It inherits from the raynardj/classical-chinese-punctuation-guwen-biaodian Classical Chinese punctuation model and was fine-tuned on over 25,000 high-quality punctuation pairs collected by the CBDB (China Biographical Database) group.
 
@@ -24,57 +24,69 @@ Here is how to use this model to get the features of a given text in PyTorch:
  ```python
  from transformers import AutoTokenizer, AutoModelForTokenClassification
 
- device = torch.device('cuda')
- model_name = 'cbdb/OfficeTitleAddressSplitter'
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- model = AutoModelForTokenClassification.from_pretrained(model_name).to(device)
  ```
 
  <font color="cornflowerblue"> 2. Load Data </font>
  ```python
  # Load your data here
- tobe_splitted = ['湖南常德協中軍都司','廣東鹽運使','漢軍鑲黃旗副都統']
  ```
 
- work-in-progress
 
  <font color="cornflowerblue"> 3. Make a prediction </font>
  ```python
- tokens_test = tokenizer.encode_plus(
-     tobe_splitted,
-     add_special_tokens=True,
-     return_attention_mask=True,
-     padding=True,
-     max_length=max_seq_len,
-     return_tensors='pt',
-     truncation=True
- )
-
- test_seq = torch.tensor(tokens_test['input_ids'])
- test_mask = torch.tensor(tokens_test['attention_mask'])
-
- # get predictions for test data
- with torch.no_grad():
-     outputs = model(test_seq.cuda(), test_mask.cuda())
-     outputs = outputs.logits.detach().cpu().numpy()
-
- softmax_score = softmax(outputs)
- # pred_class_dict = {k: v for k, v in zip(label2idx.keys(), softmax_score[0])}
- softmax_score = np.argmax(softmax_score, axis=2)[0]
-
- inputs = tokenizer(tobe_splitted, return_tensors="pt", padding=True).to(device)
- translated = model.generate(**inputs, max_length=128)
- tran = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
- for c, t in zip(tobe_splitted, tran):
-     print(f'{c}: {t}')
  ```
- 講筵官: Lecturer<br>
- 判司簿尉: Supervisor of the Commandant of Records<br>
- 散騎常侍: Policy Advisor<br>
- 殿中省尚輦奉御: Chief Steward of the Palace Administration<br>
-
- work-in-progress
 
 
  ### <font color="IndianRed">Authors </font>
 
  ---
 
  # <font color="IndianRed"> OTAS (Office Title Address Splitter)</font>
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1UoG3QebyBlK6diiYckiQv-5dRB9dA4iv?usp=sharing)
 
  Our model <font color="cornflowerblue">OTAS (Office Title Address Splitter)</font> is a Named Entity Recognition model for Classical Chinese intended to <font color="IndianRed">split the address portion of Classical Chinese office titles</font>. It inherits from the raynardj/classical-chinese-punctuation-guwen-biaodian Classical Chinese punctuation model and was fine-tuned on over 25,000 high-quality punctuation pairs collected by the CBDB (China Biographical Database) group.
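To make the intended behavior concrete before the code below: the model labels each character, and a '。' label marks where the title/address boundary is inserted. A minimal, model-free sketch of that convention (the `split_title` helper and the hand-made label lists are illustrative assumptions, not part of the released code):

```python
# Illustrative only: mimic the OTAS post-processing step with hand-made labels.
def split_title(title, labels):
    """Insert the '。' boundary at the position whose label is '。'."""
    if '。' in labels:
        i = labels.index('。')
        return title[:i] + '。' + title[i:]
    return title  # no boundary predicted

print(split_title('漢軍鑲黃旗副都統', ['O'] * 5 + ['。'] + ['O'] * 2))  # → 漢軍鑲黃旗。副都統
print(split_title('兵部右侍郎', ['O'] * 5))  # → 兵部右侍郎 (unchanged)
```

In practice the label sequence comes from the model's per-token predictions, as shown in the usage steps.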
 
 
  ```python
  from transformers import AutoTokenizer, AutoModelForTokenClassification
 
+ import torch
+ import numpy as np
+ from scipy.special import softmax
+
+ PRETRAINED = "cbdb/OfficeTitleAddressSplitter"
+ tokenizer = AutoTokenizer.from_pretrained(PRETRAINED)
+ model = AutoModelForTokenClassification.from_pretrained(PRETRAINED)
+ idx2label = model.config.id2label  # maps predicted label ids back to labels
  ```
 
  <font color="cornflowerblue"> 2. Load Data </font>
  ```python
  # Load your data here
+ test_list = ['漢軍鑲黃旗副都統', '兵部右侍郎', '盛京戶部侍郎']
  ```
 
  <font color="cornflowerblue"> 3. Make a prediction </font>
  ```python
+ def predict_class(test):
+     tokens_test = tokenizer.encode_plus(
+         test,
+         add_special_tokens=True,
+         return_attention_mask=True,
+         padding=True,
+         max_length=128,
+         return_tensors='pt',
+         truncation=True
+     )
+
+     test_seq = torch.tensor(tokens_test['input_ids'])
+     test_mask = torch.tensor(tokens_test['attention_mask'])
+
+     inputs = {
+         "input_ids": test_seq,
+         "attention_mask": test_mask
+     }
+     with torch.no_grad():
+         outputs = model(**inputs)
+         outputs = outputs.logits.detach().cpu().numpy()
+
+     softmax_score = softmax(outputs)
+     softmax_score = np.argmax(softmax_score, axis=2)[0]
+     return test_seq, softmax_score
+
+ for test_sen0 in test_list:
+     test_seq, pred_class_proba = predict_class(test_sen0)
+     label = [idx2label[i] for i in pred_class_proba]
+
+     element_to_find = '。'
+
+     if element_to_find in label:
+         index = label.index(element_to_find)
+         test_sen_pred = [i for i in test_sen0]
+         test_sen_pred.insert(index, element_to_find)
+         test_sen_pred = ''.join(test_sen_pred)
+     else:
+         test_sen_pred = ''.join(test_sen0)
+
+     print(test_sen_pred)
  ```
+ 漢軍鑲黃旗。副都統<br>
+ 兵部右侍郎<br>
+ 盛京。戶部侍郎<br>
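A note on the logits-to-labels step in `predict_class` above: `softmax` followed by `argmax` yields the same label ids as `argmax` on the raw logits, because softmax is monotonic. A pure-Python sketch with toy numbers (the 2-label scheme and the logit values are made up for illustration):

```python
import math

def logits_to_label_ids(logits):
    """logits: list of per-token score lists; returns one label id per token."""
    ids = []
    for scores in logits:
        # softmax shown for completeness; argmax alone gives the same id
        exps = [math.exp(s) for s in scores]
        total = sum(exps)
        probs = [e / total for e in exps]
        ids.append(max(range(len(probs)), key=probs.__getitem__))
    return ids

# toy logits for 3 tokens over 2 labels: 0 = 'O' (no boundary), 1 = '。'
print(logits_to_label_ids([[2.0, -1.0], [0.5, 3.0], [1.0, 0.0]]))  # → [0, 1, 0]
```

The ids are then mapped back to labels (e.g. via `idx2label`) and the '。' position, if any, is where the boundary is inserted.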
 
 
 
 
 
  ### <font color="IndianRed">Authors </font>