LH-Tech-AI committed on
Commit
d80d74f
·
verified ·
1 Parent(s): a6a4d03

Create use.py

Browse files
Files changed (1) hide show
  1. use.py +131 -0
use.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from transformers import BertForQuestionAnswering, BertTokenizerFast
4
+
5
# ── Config ───────────────────────────────────────────────────
# Directory handed to `from_pretrained` for both tokenizer and model.
MODEL_DIR = "model"
# Passed to the tokenizer as `max_length` (tokens per chunk, padding included).
MAX_LENGTH = 384
# Passed to the tokenizer as `stride` when the context overflows one chunk.
DOC_STRIDE = 128
# Number of top start / end logits examined per chunk.
N_BEST = 20
# Longest answer span, in tokens, that is accepted.
MAX_ANS_LEN = 30
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ── Model loading ────────────────────────────────────────────
tokenizer = BertTokenizerFast.from_pretrained(MODEL_DIR)
model = BertForQuestionAnswering.from_pretrained(MODEL_DIR)
model = model.to(DEVICE)
model.eval()  # inference mode
print(f"✅ Model loaded on {DEVICE}")
17
+
18
def answer_question(question: str, context: str) -> dict:
    """Extract the best-scoring answer span for *question* from *context*.

    The question/context pair is tokenized into one or more chunks of
    ``MAX_LENGTH`` tokens (only the context is truncated; consecutive chunks
    use a stride of ``DOC_STRIDE``). For every chunk the ``N_BEST`` highest
    start logits and the ``N_BEST`` highest end logits are combined into
    candidate spans. A span is valid when both ends lie in the context
    (sequence id 1), the end is not before the start, and it covers at most
    ``MAX_ANS_LEN`` tokens. The score of a span is the sum of its start and
    end logits.

    Args:
        question: The question text (first sequence given to the tokenizer).
        context: The passage to search (second sequence given to the tokenizer).

    Returns:
        A dict with the keys ``"answer"`` (the matching substring of
        *context*), ``"score"`` (rounded to 4 decimals), ``"start"`` and
        ``"end"`` (character offsets into *context*). When no valid span
        exists the dict is
        ``{"answer": "No answer found.", "score": -999, "start": -1, "end": -1}``.
    """
    inputs = tokenizer(
        question,
        context,
        max_length=MAX_LENGTH,
        truncation="only_second",
        stride=DOC_STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
        return_tensors="pt",
    )

    # Both keys are bookkeeping produced by the tokenizer, not model inputs,
    # so they have to be removed before the forward pass.
    offset_mapping = inputs.pop("offset_mapping").numpy()  # (n_chunks, seq_len, 2)
    inputs.pop("overflow_to_sample_mapping")
    n_chunks = len(inputs["input_ids"])
    sequence_ids = [inputs.sequence_ids(i) for i in range(n_chunks)]

    model_inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**model_inputs)

    start_logits = outputs.start_logits.cpu().numpy()  # (n_chunks, seq_len)
    end_logits = outputs.end_logits.cpu().numpy()

    # Track only the running best as (score, char_start, char_end). The strict
    # ">" keeps the first span seen on ties, like max() over a full list would.
    best = None

    for starts, ends, offsets, seq_ids in zip(
        start_logits, end_logits, offset_mapping, sequence_ids
    ):
        # Top-N_BEST indices in descending logit order, restricted once to
        # tokens that belong to the context (sequence id 1).
        top_starts = [
            int(i) for i in np.argsort(starts)[-1:-N_BEST - 1:-1] if seq_ids[i] == 1
        ]
        top_ends = [
            int(i) for i in np.argsort(ends)[-1:-N_BEST - 1:-1] if seq_ids[i] == 1
        ]

        for s in top_starts:
            for e in top_ends:
                if e < s or e - s + 1 > MAX_ANS_LEN:
                    continue
                score = float(starts[s] + ends[e])
                if best is None or score > best[0]:
                    best = (score, int(offsets[s][0]), int(offsets[e][1]))

    if best is None:
        return {"answer": "No answer found.", "score": -999, "start": -1, "end": -1}

    score, char_start, char_end = best
    return {
        "answer": context[char_start:char_end],
        "score": round(score, 4),
        "start": char_start,
        "end": char_end,
    }
75
+
76
+
77
def ask(question: str, context: str):
    """Run `answer_question` on the pair and print a short report.

    Prints the question, the extracted answer, its score and the character
    range of the answer, followed by a 60-character separator line.
    """
    outcome = answer_question(question, context)
    report = (
        f"❓ Question: {question}",
        f"💬 Answer : {outcome['answer']}",
        f"📊 Score : {outcome['score']}",
        f"📍 Position: Char {outcome['start']}–{outcome['end']}",
        "-" * 60,
    )
    for line in report:
        print(line)
84
+
85
+
86
+
87
# ── Demo contexts ────────────────────────────────────────────
ctx1 = """
The Amazon rainforest, also known as Amazonia, is a moist broadleaf
tropical rainforest in the Amazon biome that covers most of the Amazon
basin of South America. This basin encompasses 7,000,000 km² of which
5,500,000 km² are covered by the rainforest. The majority of the forest
is contained within Brazil, with 60% of the rainforest.
"""

ctx2 = """
The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars
in Paris, France. It was constructed from 1887 to 1889 as the centerpiece
of the 1889 World's Fair. The tower is 330 metres tall and is the tallest
structure in Paris.
"""

# The passage is repeated three times to make the context longer.
ctx3 = """
Python is a high-level, general-purpose programming language. Its design
philosophy emphasizes code readability with the use of significant indentation.
Python is dynamically typed and garbage-collected. It supports multiple
programming paradigms, including structured, object-oriented and functional
programming. It was created by Guido van Rossum and first released in 1991.
Python consistently ranks as one of the most popular programming languages.
It is widely used in data science, machine learning, web development, and
automation. The Python Package Index (PyPI) hosts hundreds of thousands of
third-party modules. The standard library is very extensive, offering tools
suited to many tasks.
""" * 3

# Run the canned examples in their original order.
for _demo_question, _demo_context in (
    ("How much of the Amazon rainforest is in Brazil?", ctx1),
    ("When was the Eiffel Tower built?", ctx2),
    ("When was Python first released?", ctx3),
):
    ask(_demo_question, _demo_context)
118
+
119
# ── Interactive mode ─────────────────────────────────────────
# Reads one context, then answers questions against it until the user quits.
print("\n" + "=" * 60)
print("🎮 Interactive mode – stop with 'quit'")
print("=" * 60)

try:
    context_interactive = input("📄 Input context:\n> ").strip()
    while True:
        q = input("\n❓ Question (or type 'quit'): ").strip()
        if q.lower() == "quit":
            print("👋 Bye.")
            break
        if not q:
            # Blank line: prompt again instead of querying the model.
            continue
        ask(q, context_interactive)
except (EOFError, KeyboardInterrupt):
    # input() raises EOFError when stdin is closed or exhausted (e.g. piped
    # input, Ctrl-D) and Ctrl-C raises KeyboardInterrupt; treat both like an
    # explicit 'quit' rather than ending with a traceback.
    print("\n👋 Bye.")