Copycats committed
Commit 1491fa3 • 1 Parent(s): 7afe39f

Update README.md

Files changed (1)
  1. README.md +38 -32
README.md CHANGED
@@ -37,46 +37,52 @@ license: cc-by-nc-4.0
  import torch
  from transformers import AutoModelForQuestionAnswering, AutoTokenizer
 
- context = "your context"
- question = "your question"
+ def predict_answer(qa_text_pair):
+     # Encoding
+     encodings = tokenizer(
+         qa_text_pair['question'], qa_text_pair['context'],
+         max_length=512,
+         truncation=True,
+         padding="max_length",
+         return_token_type_ids=False,
+         return_offsets_mapping=True
+     )
+     encodings = {key: torch.tensor([val]).to(device) for key, val in encodings.items()}
+     # Predict
+     with torch.no_grad():
+         pred = model(encodings['input_ids'], encodings['attention_mask'])
+         start_logits, end_logits = pred.start_logits, pred.end_logits
+         token_start_index, token_end_index = start_logits.argmax(dim=-1), end_logits.argmax(dim=-1)
+         pred_ids = encodings['input_ids'][0][token_start_index: token_end_index + 1]
+     # Answer start/end offset of context.
+     answer_start_offset = int(encodings['offset_mapping'][0][token_start_index][0][0])
+     answer_end_offset = int(encodings['offset_mapping'][0][token_end_index][0][1])
+     answer_offset = (answer_start_offset, answer_end_offset)
+     # Decoding
+     answer_text = tokenizer.decode(pred_ids)  # text
+     del encodings
+     return {'answer_text': answer_text, 'answer_offset': answer_offset}
 
- # Load fine-tuned MRC model by HuggingFace Model Hub
- HUGGINGFACE_MODEL_PATH = "bespin-global/klue-bert-base-aihub-mrc"
- tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_PATH)
- model = AutoModelForQuestionAnswering.from_pretrained(HUGGINGFACE_MODEL_PATH)
+
+ # Load fine-tuned MRC model
+ MODEL_PATH = "bespin-global/klue-bert-base-aihub-mrc"
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+ model = AutoModelForQuestionAnswering.from_pretrained(MODEL_PATH)
 
  # gpu or cpu
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
  model.to(device)
  model.eval()
 
- # Encoding
- encodings = tokenizer(
-     context, question,
-     max_length=512,
-     truncation=True,
-     padding="max_length",
-     return_token_type_ids=False,
-     return_offsets_mapping=True
- )
- encodings = {key: torch.tensor([val]) for key, val in encodings.items()}
- input_ids = encodings["input_ids"].to(device)
- attention_mask = encodings["attention_mask"].to(device)
- offset_mappings = encodings["offset_mapping"].to(device)
-
- # Predict
- pred = model(input_ids, attention_mask=attention_mask)
- start_logits, end_logits = pred.start_logits, pred.end_logits
- token_start_index, token_end_index = start_logits.argmax(dim=-1), end_logits.argmax(dim=-1)
- pred_ids = input_ids[0][token_start_index: token_end_index + 1]
-
- # Answer start/end offset of context.
- answer_start_offset = int(offset_mappings[0][token_start_index][0][0])
- answer_end_offset = int(offset_mappings[0][token_end_index][0][1])
+ context = '''애플 M1(영어: Apple M1)은 애플이 자사의 매킨토시 컴퓨터용으로 설계한 최초의 ARM 기반 SoC이다.
+ 4세대 맥북 에어, 5세대 맥 미니, 13인치 5세대 맥북 프로, 5세대 아이패드 프로에 선보였다. 5나노미터 공정을 사용하여 제조된 최초의 개인용 컴퓨터 칩이다.
+ 애플은 저전력 실리콘의, 세계에서 가장 빠른 ARM 기반의 중앙 처리 장치(CPU) 코어, 그리고 세계 최고의 CPU 성능 대 와트를 갖추고 있다고 주장하고 있다.'''
+ question = "애플이 m1에 대해 주장하는건 뭐야?"
 
- # Answer text
- answer_text = tokenizer.decode(pred_ids)
- print(f"ANSWER : {answer_text}")
+ qa_text_pair = {'context': context, 'question': question}
+ result = predict_answer(qa_text_pair)
+ print('Answer Text: ', result['answer_text'])      # 저전력 실리콘의, 세계에서 가장 빠른 ARM 기반의 중앙 처리 장치 ( CPU ) 코어, 그리고 세계 최고의 CPU 성능 대 와트를 갖추고 있다고 주장하고 있다.
+ print('Answer Offset: ', result['answer_offset'])  # (159, 246)
  ```
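
As a cross-check of the snippet above, the same checkpoint can also be driven through transformers' built-in `question-answering` pipeline, which bundles the encode / argmax / decode steps that `predict_answer` performs by hand. A minimal sketch, not part of this commit, using the standard pipeline API and the commit's Korean example (the question asks what Apple claims about the M1):

```python
from transformers import pipeline

# Question-answering pipeline backed by the same fine-tuned checkpoint;
# the tokenizer is loaded automatically from the same repo.
qa = pipeline("question-answering", model="bespin-global/klue-bert-base-aihub-mrc")

# The claim sentence from the example context above, with the same question.
context = "애플은 저전력 실리콘의, 세계에서 가장 빠른 ARM 기반의 중앙 처리 장치(CPU) 코어, 그리고 세계 최고의 CPU 성능 대 와트를 갖추고 있다고 주장하고 있다."
question = "애플이 m1에 대해 주장하는건 뭐야?"

result = qa(question=question, context=context)
print(result["answer"])                  # extracted answer span (text)
print((result["start"], result["end"]))  # character offsets into `context`
```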

## Citing & Authors