File size: 9,202 Bytes
a23be62
7769f4e
a23be62
 
 
 
 
 
 
 
793d718
a23be62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85b50eb
24c8832
 
 
 
 
 
 
 
 
c9f64e3
 
85b50eb
c9f64e3
 
 
 
 
85b50eb
b78f5bc
c9f64e3
 
a23be62
 
c9f64e3
a23be62
c9f64e3
a23be62
c9f64e3
a23be62
 
 
 
 
 
7769f4e
 
 
 
 
 
 
 
 
 
 
 
 
a23be62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7769f4e
 
 
 
a23be62
 
7769f4e
a23be62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7769f4e
a23be62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
import torch
from transformers import BertTokenizer, AutoModelForTokenClassification, pipeline
import pickle # for saving and loading Python objects
from openai import OpenAI
import tiktoken
from transformers import AutoConfig, AutoTokenizer
import os
import torch.nn as nn
from transformers import AutoModel, AutoConfig

client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

# Define BiLSTMForTokenClassification Class


class BiLSTMForTokenClassification(nn.Module):
    """
        This model combines BERT embeddings with a Bidirectional LSTM (BiLSTM) for token-level classification
        tasks like Named Entity Recognition (NER).

        Args:
            pretrained_model_name_or_path: Name of the pre-trained BERT model to use (e.g., "bert-base-cased").
            num_labels: Number of different labels to predict.
            hidden_size: Dimension of the hidden states in the BiLSTM (default: 128).
            num_lstm_layers: Number of stacked BiLSTM layers (default: 1).
    """
    def __init__(self, model_name, num_labels, hidden_size=128, num_lstm_layers=1):
        super().__init__()
        self.num_labels = num_labels
        self.config = AutoConfig.from_pretrained(model_name)
        self.bert = AutoModel.from_pretrained(model_name)

        # Freeze BERT embeddings
        for name, param in self.bert.named_parameters():
            if name.startswith("embeddings"):
                param.requires_grad = False

        self.bilstm = nn.LSTM(self.bert.config.hidden_size, hidden_size, num_layers=num_lstm_layers, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_size * 2, num_labels)

    def forward(self, input_ids, attention_mask=None, labels=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]
        lstm_output, _ = self.bilstm(sequence_output)
        lstm_output = self.dropout(lstm_output)

        logits = self.classifier(lstm_output)
        loss = None

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            active_loss = attention_mask.view(-1) == 1
            active_logits = logits.view(-1, self.num_labels)
            active_labels = torch.where(active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels))
            valid_mask = (active_labels >= 0) & (active_labels < self.num_labels)
            active_logits = active_logits[valid_mask]
            active_labels = active_labels[valid_mask]
            loss = loss_fct(active_logits, active_labels)

        return {'loss': loss, 'logits': logits}

def load_models():
    """
    Loads the custom BiLSTM NER model and tokenizer from local files.

    Returns:
        model (BiLSTMForTokenClassification): The loaded BiLSTM NER model.
        tokenizer (AutoTokenizer): The loaded tokenizer associated with the model.
        id2label_ner (dict): A dictionary mapping numerical label indices to NER tags.
    """

    model_dir = "./models/bilstm_ner"
    tokenizer_dir = "./models/tokenizer"

    id2label_ner = {
        0: 'O', 1: 'I-art', 2: 'B-org', 3: 'B-geo', 4: 'I-per', 5: 'B-eve',
        6: 'I-geo', 7: 'B-per', 8: 'I-nat', 9: 'B-art', 10: 'B-tim', 11: 'I-gpe',
        12: 'I-tim', 13: 'B-nat', 14: 'B-gpe', 15: 'I-org', 16: 'I-eve'
    }

    config = AutoConfig.from_pretrained(model_dir, local_files_only=True)
    config.id2label = id2label_ner
    config.num_labels = len(id2label_ner)

    model = BiLSTMForTokenClassification(model_name=config._name_or_path, num_labels=config.num_labels)
    model.config.id2label = id2label_ner
    model.load_state_dict(torch.load(os.path.join(model_dir, 'pytorch_model.bin'), map_location=torch.device('cpu')))
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, local_files_only=True)

    return model, tokenizer, id2label_ner

# QA model
qa_model = pipeline('question-answering', model='deepset/bert-base-cased-squad2')

# Function to extract information
def extract_information(text, bert_model, bilstm_model, ner_tokenizer, id2label_ner):
    """
    Extracts information from the given text using NER tags and generates 'Why' or 'How' questions with answers.

    Args:
        text: The input text string.
        bert_model: The pre-trained BERT model for token classification.
        bilstm_model: The BiLSTM model for NER tag prediction.
        ner_tokenizer: The tokenizer for the BiLSTM model.
        id2label_ner: A dictionary mapping numerical label indices to NER tags.

    Returns:
        A dictionary containing extracted 4W information, generated question, and answer.
    """
    extracted_info = {}

    ner_tags = predict_tags(text, bilstm_model, ner_tokenizer, id2label_ner)
    
    extracted_info.update(extract_4w_qa(text, ner_tags))
    
    qa_result = generate_why_or_how_question_and_answer(extracted_info, text)
    if qa_result:
        extracted_info.update(qa_result)
        prompt = f"Question: {qa_result['question']}\nContext: {text}\nAnswer:"
        extracted_info["Token Count"] = count_tokens(prompt)
        
    return extracted_info


def predict_tags(sentence, model, tokenizer, label_map):
    """
    Predicts NER tags for a given sentence using the specified model and tokenizer.

    Args:
        sentence (str): The input sentence.
        model (nn.Module): The NER model.
        tokenizer: The tokenizer used for the model.
        label_map (dict): A dictionary mapping numerical label indices to their corresponding tags.

    Returns:
        list: A list of predicted tags for each token in the sentence.
    """
    tokens = tokenizer.tokenize(tokenizer.decode(tokenizer.encode(sentence)))
    inputs = tokenizer.encode(sentence, return_tensors='pt')

    outputs = model(inputs)
    logits = outputs['logits']
    predictions = torch.argmax(logits, dim=2)

    labels = [label_map.get(prediction.item(), "O") for prediction in predictions[0][1:-1]]
    return labels

def extract_4w_qa(sentence, ner_tags):
    """
    Extracts 4w (Who, What, When, Where) information from a sentence
    using NER tags and a question-answering model.

    Args:
        sentence: The input sentence as a string.
        ner_tags: A list of predicted NER tags for each token in the sentence.

    Returns:
        A dictionary where keys are 5W1H question words and values are the corresponding
        answers extracted from the sentence.
    """
    result = {}
    questions = {
        "B-per": "Who",
        "I-per": "Who",
        "B-geo": "Where",
        "I-geo": "Where",
        "B-org": "What organization",
        "I-org": "What organization",
        "B-tim": "When",
        "I-tim": "When",
        "B-art": "What art",
        "I-art": "What art",
        "B-eve": "What event",
        "I-eve": "What event",
        "B-nat": "What natural phenomenon",
        "I-nat": "What natural phenomenon",
    }

    for ner_tag, entity in zip(ner_tags, sentence.split()):  # Removed pos_tags
        if ner_tag in questions:
            question = f"{questions[ner_tag]} is {entity}?"  # Removed pos_tag
            answer = qa_model(question=question, context=sentence)["answer"]
            result[questions[ner_tag]] = answer

    return result

def count_tokens(text):
    """
    Counts the number of tokens in a text string using the tiktoken encoding for GPT-3.5 Turbo.

    Args:
        text: The input text string.

    Returns:
        The number of tokens in the text.
    """
    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
    return len(encoding.encode(text))

def generate_why_or_how_question_and_answer(extracted_info, sentence):
    """
    Generates a "Why" or "How" question based on the extracted 4W information and gets the answer using GPT-3.5.

    Args:
        extracted_info: A dictionary containing the extracted 4W information.
        sentence: The original sentence.

    Returns:
        A dictionary containing the generated question and its answer, or None if no relevant question can be generated.
    """

    prompt_template = """
    Given the following extracted information and the original sentence, generate a relevant "Why" or "How" question and provide a concise answer based on the given context.

    Extracted Information: {extracted_info}
    Sentence: {sentence}

    Question and Answer:
    """

    prompt = prompt_template.format(extracted_info=extracted_info, sentence=sentence)
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
        max_tokens=150,
        stop=None,
        temperature=0.5,
    )

    question_and_answer = response.choices[0].message.content.strip()

    if question_and_answer:
        try:
            question, answer = question_and_answer.split("\n", 1)
            return {"question": question, "answer": answer}
        except ValueError:
            return None
    else:
        return None