# Similarity between two sentences (fine-tuned with the KoELECTRA-Small-v3 model and the KorSTS dataset)

## Usage (Amazon SageMaker inference applicable)
The inference script uses the SageMaker Inference Toolkit interface as-is, so the model can be deployed to a SageMaker endpoint without modification; a minimal deployment sketch follows the sample data below.

### inference_korsts.py

```python
import json
import sys
import logging
import torch
from torch import nn
from transformers import ElectraConfig
from transformers import ElectraModel, AutoTokenizer, ElectraTokenizer, ElectraForSequenceClassification

logging.basicConfig(
    level=logging.INFO,
    format='[{%(filename)s:%(lineno)d} %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(filename='tmp.log'),
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

max_seq_length = 128
tokenizer = AutoTokenizer.from_pretrained("daekeun-ml/koelectra-small-v3-korsts")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Hugging Face pre-trained model: 'monologg/koelectra-small-v3-discriminator'
def model_fn(model_path=None):
    ####
    # If you have your own fine-tuned model, load it from model_path instead:
    #config = ElectraConfig.from_json_file(f'{model_path}/config.json')
    #model = ElectraForSequenceClassification.from_pretrained(f'{model_path}/model.pth', config=config)
    ####
    model = ElectraForSequenceClassification.from_pretrained('daekeun-ml/koelectra-small-v3-korsts')
    model.to(device)
    return model


def input_fn(input_data, content_type="application/jsonlines"):
    # Decode the request body and parse one JSON object per line.
    data_str = input_data.decode("utf-8")
    jsonlines = data_str.split("\n")
    transformed_inputs = []

    for jsonline in jsonlines:
        text = json.loads(jsonline)["text"]
        logger.info("input text: {}".format(text))
        encode_plus_token = tokenizer.encode_plus(
            text,
            max_length=max_seq_length,
            add_special_tokens=True,
            return_token_type_ids=False,
            padding="max_length",
            return_attention_mask=True,
            return_tensors="pt",
            truncation=True,
        )
        transformed_inputs.append(encode_plus_token)

    return transformed_inputs


def predict_fn(transformed_inputs, model):
    predicted_classes = []

    for data in transformed_inputs:
        data = data.to(device)
        output = model(**data)
        # The regression head emits a single similarity score per sentence pair.
        prediction_dict = {}
        prediction_dict['score'] = output[0].squeeze().cpu().detach().numpy().tolist()

        jsonline = json.dumps(prediction_dict)
        logger.info("jsonline: {}".format(jsonline))
        predicted_classes.append(jsonline)

    predicted_classes_jsonlines = "\n".join(predicted_classes)
    return predicted_classes_jsonlines


def output_fn(outputs, accept="application/jsonlines"):
    return outputs, accept
```
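
`input_fn` expects the request body to be a UTF-8 encoded JSON Lines payload with one `{"text": [sentence1, sentence2]}` object per line. A minimal sketch of building such a payload in Python (the sentence pairs are taken from the sample data below):

```python
import json

# Sentence pairs from samples/korsts.txt
sentence_pairs = [
    ["맛있는 라면을 먹고 싶어요", "후루룩 쩝쩝 후루룩 쩝쩝 맛좋은 라면"],
    ["뽀로로는 내친구", "머신러닝은 러닝머신이 아닙니다."],
]

# One JSON object per line, joined by newlines and UTF-8 encoded,
# exactly the shape that input_fn decodes and splits.
payload = "\n".join(
    json.dumps({"text": pair}, ensure_ascii=False) for pair in sentence_pairs
)
model_input_data = payload.encode("utf-8")
```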

### test.py
```python
>>> from inference_korsts import model_fn, input_fn, predict_fn, output_fn
>>> with open('./samples/korsts.txt', mode='rb') as file:
...     model_input_data = file.read()

>>> model = model_fn()
>>> transformed_inputs = input_fn(model_input_data)
>>> predicted_classes_jsonlines = predict_fn(transformed_inputs, model)
>>> model_outputs = output_fn(predicted_classes_jsonlines)
>>> print(model_outputs[0])

[{inference_korsts.py:44} INFO - input text: ['맛있는 라면을 먹고 싶어요', '후루룩 쩝쩝 후루룩 쩝쩝 맛좋은 라면']
[{inference_korsts.py:44} INFO - input text: ['뽀로로는 내친구', '머신러닝은 러닝머신이 아닙니다.']
[{inference_korsts.py:71} INFO - jsonline: {"score": 4.786738872528076}
[{inference_korsts.py:71} INFO - jsonline: {"score": 0.2319069355726242}
{"score": 4.786738872528076}
{"score": 0.2319069355726242}
```
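
For reference, KorSTS follows the STS-B convention of scoring sentence-pair similarity on a 0-5 scale, so the first pair above is judged highly similar (about 4.79) and the second nearly unrelated (about 0.23).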

### Sample data (samples/korsts.txt)
```
{"text": ["맛있는 라면을 먹고 싶어요", "후루룩 쩝쩝 후루룩 쩝쩝 맛좋은 라면"]}
{"text": ["뽀로로는 내친구", "머신러닝은 러닝머신이 아닙니다."]}
```

(Roughly: "I want to eat some tasty ramen" / "Slurp slurp, chomp chomp, tasty ramen", and "Pororo is my friend" / "Machine learning is not a running machine.")
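
### Deploying to a SageMaker endpoint
A minimal deployment sketch using the SageMaker Python SDK (v2), assuming the model artifacts have been packaged into a `model.tar.gz` together with `inference_korsts.py`; the S3 path, IAM role, instance type, and framework versions below are placeholders to replace with values from your own environment:

```python
from sagemaker.pytorch import PyTorchModel
from sagemaker.serializers import JSONLinesSerializer
from sagemaker.deserializers import JSONLinesDeserializer

# Placeholders: substitute your own artifact location and execution role.
pytorch_model = PyTorchModel(
    model_data="s3://<your-bucket>/model/model.tar.gz",
    role="<your-sagemaker-execution-role>",
    entry_point="inference_korsts.py",
    framework_version="1.8.1",
    py_version="py36",
)

predictor = pytorch_model.deploy(
    initial_instance_count=1,
    instance_type="ml.m5.xlarge",
    serializer=JSONLinesSerializer(),
    deserializer=JSONLinesDeserializer(),
)

# Each request line mirrors the sample data above.
outputs = predictor.predict([
    {"text": ["맛있는 라면을 먹고 싶어요", "후루룩 쩝쩝 후루룩 쩝쩝 맛좋은 라면"]},
])
```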

## References
- KoELECTRA: https://github.com/monologg/KoELECTRA
- KorNLI and KorSTS: https://github.com/kakaobrain/KorNLUDatasets