Dataset used to train this model: real-jiakai/chinese-squadv2
How to use real-jiakai/bert-base-chinese-finetuned-squadv2 with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("question-answering", model="real-jiakai/bert-base-chinese-finetuned-squadv2")

# Load model directly
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
tokenizer = AutoTokenizer.from_pretrained("real-jiakai/bert-base-chinese-finetuned-squadv2")
model = AutoModelForQuestionAnswering.from_pretrained("real-jiakai/bert-base-chinese-finetuned-squadv2")

This model is a fine-tuned version of bert-base-chinese on the Chinese SQuAD v2.0 dataset.
This model is designed for Chinese question answering, specifically extractive QA, where the answer is a span taken from a given context paragraph. Following the SQuAD v2.0 format, it handles both answerable and unanswerable questions.
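In a SQuAD v2.0-style example, a question is paired with a context paragraph. Answerable questions carry the answer text and its character offset in the context, while unanswerable ones carry an empty answer list. The sketch below uses the field layout of the Hugging Face `squad_v2` dataset (including `id`, `question`, `context`, and `answers` with parallel `text` and `answer_start` lists). Whether `real-jiakai/chinese-squadv2` uses exactly these field names is an assumption, and both records are illustrative rather than taken from the dataset.

```python
# Sketch of SQuAD v2.0-style records. Field names are assumed to follow the
# Hugging Face "squad_v2" layout; the two records are made up for illustration.
context = "第十五届中国国际航空航天博览会于2024年11月12日至17日在珠海国际航展中心举行。"

examples = [
    {   # Answerable ("Where is the airshow held?"): the gold answer is a span of the context.
        "id": "demo-1",
        "question": "航展在哪里举行?",
        "context": context,
        "answers": {
            "text": ["珠海国际航展中心"],
            "answer_start": [context.find("珠海国际航展中心")],
        },
    },
    {   # Unanswerable ("How much do tickets cost?"): SQuAD v2.0 uses empty answer lists.
        "id": "demo-2",
        "question": "航展门票多少钱?",
        "context": context,
        "answers": {"text": [], "answer_start": []},
    },
]


def is_answerable(example):
    """A question is answerable iff it has at least one gold answer."""
    return len(example["answers"]["text"]) > 0


def check_spans(example):
    """Check that every gold answer really occurs at its answer_start offset."""
    ctx = example["context"]
    answers = example["answers"]
    for text, start in zip(answers["text"], answers["answer_start"]):
        if ctx[start:start + len(text)] != text:
            return False
    return True


for ex in examples:
    label = "answerable" if is_answerable(ex) else "unanswerable"
    print(ex["id"], label, check_spans(ex))
```

Offsets are counted in characters of the context string, so for Chinese text each character advances `answer_start` by one.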
Key features:
The model was trained on the Chinese SQuAD v2.0 dataset, which contains:
Training Set:
Validation Set:
Final evaluation metrics:
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
import torch

# Load model and tokenizer
model_name = "real-jiakai/bert-base-chinese-finetuned-squadv2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)

def get_answer(question, context, threshold=0.0):
    # Tokenize input with maximum sequence length of 384
    inputs = tokenizer(
        question,
        context,
        return_tensors="pt",
        max_length=384,
        truncation=True
    )

    with torch.no_grad():
        outputs = model(**inputs)

    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]

    # Calculate null score (score for predicting no answer)
    null_score = start_logits[0].item() + end_logits[0].item()

    # Find the best non-null answer, excluding [CLS] position
    # Set logits at [CLS] position to negative infinity
    start_logits[0] = float('-inf')
    end_logits[0] = float('-inf')

    start_idx = torch.argmax(start_logits)
    end_idx = torch.argmax(end_logits)

    # Ensure end_idx is not less than start_idx
    if end_idx < start_idx:
        end_idx = start_idx

    answer_score = start_logits[start_idx].item() + end_logits[end_idx].item()

    # If null score is higher (beyond threshold), return "no answer"
    if null_score - answer_score > threshold:
        return "Question cannot be answered based on the given context."

    # Otherwise, return the extracted answer
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    answer = tokenizer.convert_tokens_to_string(tokens[start_idx:end_idx+1])

    # Check if answer is empty or contains only special tokens
    if not answer.strip() or answer.strip() in ['[CLS]', '[SEP]']:
        return "Question cannot be answered based on the given context."

    return answer.strip()
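
questions = [
    "本届第十五届珠海航展的亮点和主要展示内容是什么?",  # What are the highlights and main exhibits of this 15th Zhuhai Airshow?
    "珠海杀人案发生地点?"  # Where did the Zhuhai homicide case take place? (not mentioned in the context)
]
# Context: a news paragraph on the 15th China International Aviation & Aerospace Exhibition
# (Zhuhai Airshow), held 12 to 17 November 2024 at the Zhuhai International Airshow Center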
context = '第十五届中国国际航空航天博览会(珠海航展)于2024年11月12日至17日在珠海国际航展中心举行。本届航展吸引了来自47个国家和地区的超过890家企业参展,展示了涵盖"陆、海、空、天、电、网"全领域的高精尖展品。其中,备受瞩目的中国空军"八一"飞行表演队和"红鹰"飞行表演队,以及俄罗斯"勇士"飞行表演队同台献技,为观众呈现了精彩的飞行表演。此外,本届航展还首次开辟了无人机、无人船演示区,展示了多款前沿科技产品。'

for question in questions:
    answer = get_answer(question, context)
    print(f"Question: {question}")
    print(f"Answer: {answer}")
    print("-" * 50)
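In `get_answer`, the start and end positions are chosen independently over all tokens. A predicted span can therefore land in the question or on `[SEP]`. In addition, `convert_tokens_to_string` joins the per-character tokens of `bert-base-chinese` with spaces (for example `珠 海 国 际`). A common alternative in SQuAD v2.0 post-processing works in three steps. First, score every (start, end) pair inside the context only, with `end >= start` and a length cap. Second, compare the best pair with the `[CLS]` null score. Third, cut the answer out of the original context using character offsets. The sketch below does this on plain Python lists so it runs without the model. `sequence_ids` and `offsets` mimic what a fast tokenizer returns from `inputs.sequence_ids(0)` and from `return_offsets_mapping=True`. The logits are toy values, and `extract_answer` and `max_answer_len=30` are hypothetical choices, not part of this model card.

```python
# Sketch: context-restricted span search plus offset-based answer extraction.
# All inputs are toy values; extract_answer and max_answer_len=30 are hypothetical.
from typing import List, Optional, Tuple


def extract_answer(
    context: str,
    start_logits: List[float],
    end_logits: List[float],
    sequence_ids: List[Optional[int]],
    offsets: List[Tuple[int, int]],
    max_answer_len: int = 30,
    threshold: float = 0.0,
) -> Optional[str]:
    """Return the answer substring of `context`, or None when "no answer" wins."""
    # Null score: start and end both on [CLS] (position 0), as in get_answer.
    null_score = start_logits[0] + end_logits[0]

    # Only tokens from the context segment (sequence id 1) may start or end a span.
    ctx = [i for i, s in enumerate(sequence_ids) if s == 1]
    best_score, best = float("-inf"), None
    for i in ctx:
        for j in ctx:
            if j < i or j - i + 1 > max_answer_len:
                continue  # skip reversed or overly long spans
            score = start_logits[i] + end_logits[j]
            if score > best_score:
                best_score, best = score, (i, j)

    # Abstain when the null score beats the best span by more than `threshold`.
    if best is None or null_score - best_score > threshold:
        return None

    # Slice the original string by character offsets, so no spaces are inserted.
    start_char = offsets[best[0]][0]
    end_char = offsets[best[1]][1]
    return context[start_char:end_char]


# Toy input: [CLS] 哪 里 [SEP] 在 珠 海 国 际 航 展 中 心 举 行 [SEP]
context = "在珠海国际航展中心举行"
sequence_ids = [None, 0, 0, None] + [1] * len(context) + [None]
offsets = [(0, 0), (0, 1), (1, 2), (0, 0)] + [(k, k + 1) for k in range(len(context))] + [(0, 0)]

n = len(sequence_ids)
start_logits = [0.0] * n
end_logits = [0.0] * n
start_logits[1] = end_logits[1] = 9.0  # high scores on a question token are ignored
start_logits[5] = 4.0  # token 珠
end_logits[12] = 3.0   # token 心

print(extract_answer(context, start_logits, end_logits, sequence_ids, offsets))
# 珠海国际航展中心

start_logits[0] = end_logits[0] = 5.0  # null score 10.0 now beats the best span (7.0)
print(extract_answer(context, start_logits, end_logits, sequence_ids, offsets))
# None
```

With the real model, one would tokenize with `return_offsets_mapping=True` and remove `offset_mapping` from the encoding before calling `model(**inputs)`, because the model's forward method does not accept that key. The remaining inputs would be `inputs.sequence_ids(0)` and the logits converted with `.tolist()`. The all-pairs loop is quadratic in context length; many implementations only enumerate the top-k start and end positions instead.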
The model shows a significant performance disparity between answerable and unanswerable questions, which might indicate:
Users should be aware that:
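The official SQuAD v2.0 evaluation script reports exact match and F1 separately for questions with answers (`HasAns`) and without (`NoAns`). This split is where a gap between answerable and unanswerable questions becomes visible. The sketch below computes the exact-match part of that split for a few made-up predictions. It compares strings after stripping surrounding whitespace only, which is a simplification of the official normalization (that normalization lowercases and removes punctuation and English articles). Treat it as illustrative only.

```python
# Sketch of the HasAns / NoAns exact-match split used in SQuAD v2.0 evaluation.
# Simplified: only surrounding whitespace is stripped before comparing,
# and the predictions below are made up for illustration.
from typing import Dict, List


def exact_match(prediction: str, gold_answers: List[str]) -> float:
    """1.0 on a match; for unanswerable questions (no gold answers) only "" matches."""
    pred = prediction.strip()
    if not gold_answers:
        return float(pred == "")
    return float(any(pred == gold.strip() for gold in gold_answers))


def split_exact(preds: Dict[str, str], golds: Dict[str, List[str]]) -> Dict[str, float]:
    """Average exact match (in %) over all, answerable, and unanswerable questions."""
    def average(qids: List[str]) -> float:
        if not qids:
            return 0.0
        return 100.0 * sum(exact_match(preds.get(q, ""), golds[q]) for q in qids) / len(qids)

    has_ans = [q for q, answers in golds.items() if answers]
    no_ans = [q for q, answers in golds.items() if not answers]
    return {
        "exact": average(list(golds)),
        "HasAns_exact": average(has_ans),
        "NoAns_exact": average(no_ans),
    }


golds = {"q1": ["珠海国际航展中心"], "q2": ["47个国家和地区"], "q3": [], "q4": []}
preds = {"q1": "珠海国际航展中心", "q2": "47个国家", "q3": "", "q4": ""}
print(split_exact(preds, golds))
# {'exact': 75.0, 'HasAns_exact': 50.0, 'NoAns_exact': 100.0}
```

To score `get_answer` this way, its "Question cannot be answered based on the given context." message would first need to be mapped to an empty string.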
Base model: google-bert/bert-base-chinese