Create use.py
Browse files
use.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np
import torch
from transformers import BertForQuestionAnswering, BertTokenizerFast

# ── Config ───────────────────────────────────────────────────
MODEL_DIR = "model"        # local directory holding the fine-tuned checkpoint + tokenizer files
MAX_LENGTH = 384           # max tokens per chunk fed to the model
DOC_STRIDE = 128           # token overlap between consecutive chunks of a long context
N_BEST = 20                # how many top start/end logits to pair per chunk
MAX_ANS_LEN = 30           # longest accepted answer span, in tokens
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and QA model once at import time; eval() disables
# dropout so inference is deterministic.
tokenizer = BertTokenizerFast.from_pretrained(MODEL_DIR)
model = BertForQuestionAnswering.from_pretrained(MODEL_DIR).to(DEVICE)
model.eval()
print(f"✅ Model loaded on {DEVICE}")
def answer_question(question: str, context: str) -> dict:
    """Extract the best answer span for *question* from *context*.

    The context is tokenized into overlapping chunks of at most MAX_LENGTH
    tokens (overlap DOC_STRIDE) so passages longer than the model window are
    still searched.  Per chunk, the top-N_BEST start and end logits are
    paired, filtered to valid context spans of at most MAX_ANS_LEN tokens,
    and the highest-scoring candidate across all chunks wins.

    Returns a dict with keys:
        answer (str)   -- answer text sliced from *context*
                          ("No answer found." when no valid span exists)
        score (float)  -- start_logit + end_logit of the chosen span
                          (-999 sentinel when no valid span exists)
        start, end (int) -- character offsets into *context* (-1 when none)
    """
    inputs = tokenizer(
        question,
        context,
        max_length=MAX_LENGTH,
        truncation="only_second",   # only the context may be truncated, never the question
        stride=DOC_STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
        return_tensors="pt",
    )

    # Pull the bookkeeping tensors out before the forward pass: the model's
    # forward() does not accept them as keyword arguments.
    offset_mapping = inputs.pop("offset_mapping")  # (n_chunks, seq_len, 2) char offsets
    inputs.pop("overflow_to_sample_mapping")       # single (question, context) pair -> unused
    sequence_ids = [inputs.sequence_ids(i) for i in range(len(inputs["input_ids"]))]

    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    start_logits = outputs.start_logits.cpu().numpy()  # (n_chunks, seq_len)
    end_logits = outputs.end_logits.cpu().numpy()

    candidates = []
    for chunk_idx in range(len(start_logits)):
        offsets = offset_mapping[chunk_idx].numpy()
        seq_ids = sequence_ids[chunk_idx]
        # Hoist the per-chunk logit rows out of the O(N_BEST^2) pair loop.
        chunk_start = start_logits[chunk_idx]
        chunk_end = end_logits[chunk_idx]

        # Indices of the N_BEST highest logits, best first.
        s_indexes = np.argsort(chunk_start)[-1:-N_BEST - 1:-1]
        e_indexes = np.argsort(chunk_end)[-1:-N_BEST - 1:-1]

        for s in s_indexes:
            for e in e_indexes:
                # Both endpoints must land in the context (sequence id 1),
                # not in the question, special tokens, or padding (id None).
                if seq_ids[s] != 1 or seq_ids[e] != 1:
                    continue
                # Reject inverted spans and spans longer than MAX_ANS_LEN tokens.
                if e < s or e - s + 1 > MAX_ANS_LEN:
                    continue
                candidates.append({
                    "score": float(chunk_start[s] + chunk_end[e]),
                    "text": context[offsets[s][0]: offsets[e][1]],
                    "start": int(offsets[s][0]),
                    "end": int(offsets[e][1]),
                })

    if not candidates:
        # Sentinel result: no valid span found in any chunk.
        return {"answer": "No answer found.", "score": -999, "start": -1, "end": -1}

    best = max(candidates, key=lambda x: x["score"])
    return {
        "answer": best["text"],
        "score": round(best["score"], 4),
        "start": best["start"],
        "end": best["end"],
    }
def ask(question: str, context: str):
    """Run QA on (question, context) and pretty-print the result."""
    res = answer_question(question, context)
    report = (
        f"❓ Question: {question}",
        f"💬 Answer : {res['answer']}",
        f"📊 Score : {res['score']}",
        f"📍 Position: Char {res['start']}–{res['end']}",
        "-" * 60,
    )
    for line in report:
        print(line)
# ── Demo 1: short factual context, single chunk ──────────────
ctx1 = """
The Amazon rainforest, also known as Amazonia, is a moist broadleaf
tropical rainforest in the Amazon biome that covers most of the Amazon
basin of South America. This basin encompasses 7,000,000 km² of which
5,500,000 km² are covered by the rainforest. The majority of the forest
is contained within Brazil, with 60% of the rainforest.
"""
ask("How much of the Amazon rainforest is in Brazil?", ctx1)

# ── Demo 2: date-style answer ────────────────────────────────
ctx2 = """
The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars
in Paris, France. It was constructed from 1887 to 1889 as the centerpiece
of the 1889 World's Fair. The tower is 330 metres tall and is the tallest
structure in Paris.
"""
ask("When was the Eiffel Tower built?", ctx2)

# ── Demo 3: long context ─────────────────────────────────────
# The paragraph is repeated 3x so the tokenized context exceeds MAX_LENGTH
# and exercises the overlapping-chunk (DOC_STRIDE) path in answer_question.
ctx3 = """
Python is a high-level, general-purpose programming language. Its design
philosophy emphasizes code readability with the use of significant indentation.
Python is dynamically typed and garbage-collected. It supports multiple
programming paradigms, including structured, object-oriented and functional
programming. It was created by Guido van Rossum and first released in 1991.
Python consistently ranks as one of the most popular programming languages.
It is widely used in data science, machine learning, web development, and
automation. The Python Package Index (PyPI) hosts hundreds of thousands of
third-party modules. The standard library is very extensive, offering tools
suited to many tasks.
""" * 3

ask("When was Python first released?", ctx3)
# ── Interactive mode: one context, ask questions until 'quit' ─
print("\n" + "=" * 60)
print("🎮 Interactive mode – stop with 'quit'")
print("=" * 60)

context_interactive = input("📄 Input context:\n> ").strip()
# The walrus keeps prompt/strip/quit-check in one place; blank questions
# are silently skipped, exactly as before.
while (q := input("\n❓ Question (or type 'quit'): ").strip()).lower() != "quit":
    if q:
        ask(q, context_interactive)
print("👋 Bye.")
|