real-jiakai committed
Commit b9780e3
1 Parent(s): c933d8d

Update README.md

Files changed (1):
  1. README.md +58 -17
README.md CHANGED
@@ -95,23 +95,64 @@ model_name = "real-jiakai/bert-base-chinese-finetuned-squadv2"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForQuestionAnswering.from_pretrained(model_name)

-# Prepare the inputs
-question = "your_question"
-context = "your_context"
-inputs = tokenizer(
-    question,
-    context,
-    add_special_tokens=True,
-    return_tensors="pt"
-)
-
-# Get the answer
-start_scores, end_scores = model(**inputs)
-start_index = torch.argmax(start_scores)
-end_index = torch.argmax(end_scores)
-answer = tokenizer.convert_tokens_to_string(
-    tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][start_index:end_index+1])
-)
+def get_answer(question, context, threshold=0.0):
+    # Tokenize input with maximum sequence length of 384
+    inputs = tokenizer(
+        question,
+        context,
+        return_tensors="pt",
+        max_length=384,
+        truncation=True
+    )
+
+    with torch.no_grad():
+        outputs = model(**inputs)
+        start_logits = outputs.start_logits[0]
+        end_logits = outputs.end_logits[0]
+
+    # Calculate null score (score for predicting no answer)
+    null_score = start_logits[0].item() + end_logits[0].item()
+
+    # Find the best non-null answer, excluding [CLS] position
+    # Set logits at [CLS] position to negative infinity
+    start_logits[0] = float('-inf')
+    end_logits[0] = float('-inf')
+
+    start_idx = torch.argmax(start_logits)
+    end_idx = torch.argmax(end_logits)
+
+    # Ensure end_idx is not less than start_idx
+    if end_idx < start_idx:
+        end_idx = start_idx
+
+    answer_score = start_logits[start_idx].item() + end_logits[end_idx].item()
+
+    # If null score is higher (beyond threshold), return "no answer"
+    if null_score - answer_score > threshold:
+        return "Question cannot be answered based on the given context."
+
+    # Otherwise, return the extracted answer
+    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
+    answer = tokenizer.convert_tokens_to_string(tokens[start_idx:end_idx+1])
+
+    # Check if answer is empty or contains only special tokens
+    if not answer.strip() or answer.strip() in ['[CLS]', '[SEP]']:
+        return "Question cannot be answered based on the given context."
+
+    return answer.strip()
+
+questions = [
+    "本届第十五届珠海航展的亮点和主要展示内容是什么?",
+    "珠海杀人案发生地点?"
+]
+
+context = '第十五届中国国际航空航天博览会(珠海航展)于2024年11月12日至17日在珠海国际航展中心举行。本届航展吸引了来自47个国家和地区的超过890家企业参展,展示了涵盖"陆、海、空、天、电、网"全领域的高精尖展品。其中,备受瞩目的中国空军"八一"飞行表演队和"红鹰"飞行表演队,以及俄罗斯"勇士"飞行表演队同台献技,为观众呈现了精彩的飞行表演。此外,本届航展还首次开辟了无人机、无人船演示区,展示了多款前沿科技产品。'
+
+for question in questions:
+    answer = get_answer(question, context)
+    print(f"问题: {question}")
+    print(f"答案: {answer}")
+    print("-" * 50)
 ```

 ## Limitations and Bias
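
In the updated example, the first question asks about the highlights and main exhibits of the 15th Zhuhai Airshow and is answerable from the context paragraph (a news summary of the airshow held 12-17 November 2024), while the second asks where the Zhuhai killing case took place and has no answer in that context, exercising the null-score branch. For comparison, here is a minimal sketch, assuming the stock `transformers` question-answering pipeline rather than the `get_answer` helper above, of getting similar no-answer behaviour via `handle_impossible_answer`:

```python
# Minimal sketch (assumption: the generic transformers QA pipeline, not the
# get_answer helper from the diff above) showing SQuAD v2-style abstention.
from transformers import pipeline

model_name = "real-jiakai/bert-base-chinese-finetuned-squadv2"
qa = pipeline("question-answering", model=model_name, tokenizer=model_name)

result = qa(
    question="珠海杀人案发生地点?",
    context="第十五届中国国际航空航天博览会(珠海航展)于2024年11月12日至17日在珠海国际航展中心举行。",
    handle_impossible_answer=True,  # allow an empty prediction when no span fits
)

# An empty `answer` field means the null (no-answer) prediction scored highest.
print(result)
```

An empty `answer` string in the returned dict plays the same role as the "Question cannot be answered based on the given context." message in the snippet above.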