import os
import pickle  # for saving and loading Python objects

import torch
import torch.nn as nn
import tiktoken
from openai import OpenAI
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    BertForTokenClassification,
    BertTokenizer,
    pipeline,
)

# Read the OpenAI API key from the environment instead of hard-coding it in source.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

# Define BiLSTMForTokenClassification Class


class BiLSTMForTokenClassification(nn.Module):
    """
        This model combines BERT embeddings with a Bidirectional LSTM (BiLSTM) for token-level classification
        tasks like Named Entity Recognition (NER).

        Args:
            pretrained_model_name_or_path: Name of the pre-trained BERT model to use (e.g., "bert-base-cased").
            num_labels: Number of different labels to predict.
            hidden_size: Dimension of the hidden states in the BiLSTM (default: 128).
            num_lstm_layers: Number of stacked BiLSTM layers (default: 1).
    """
    def __init__(self, model_name, num_labels, hidden_size=128, num_lstm_layers=1):
        super().__init__()
        self.num_labels = num_labels
        self.config = AutoConfig.from_pretrained(model_name)
        self.bert = AutoModel.from_pretrained(model_name)

        # Freeze BERT embeddings
        for name, param in self.bert.named_parameters():
            if name.startswith("embeddings"):
                param.requires_grad = False

        self.bilstm = nn.LSTM(self.bert.config.hidden_size, hidden_size, num_layers=num_lstm_layers, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_size * 2, num_labels)

    def forward(self, input_ids, attention_mask=None, labels=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]
        lstm_output, _ = self.bilstm(sequence_output)
        lstm_output = self.dropout(lstm_output)

        logits = self.classifier(lstm_output)
        loss = None

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Assign the loss function's ignore index to padded positions, then keep
            # only labels that fall in the valid range before computing the loss.
            active_loss = attention_mask.view(-1) == 1
            active_logits = logits.view(-1, self.num_labels)
            active_labels = torch.where(active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels))
            valid_mask = (active_labels >= 0) & (active_labels < self.num_labels)
            active_logits = active_logits[valid_mask]
            active_labels = active_labels[valid_mask]
            loss = loss_fct(active_logits, active_labels)

        return {'loss': loss, 'logits': logits}
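
# A minimal sanity-check sketch (not part of the original pipeline): it builds the model
# with a hypothetical label count and runs a dummy batch to show that the logits come out
# with shape (batch_size, sequence_length, num_labels).
def _demo_bilstm_forward():
    demo_model = BiLSTMForTokenClassification("bert-base-cased", num_labels=17)
    demo_model.eval()
    dummy_input_ids = torch.randint(0, demo_model.config.vocab_size, (1, 12))
    with torch.no_grad():
        out = demo_model(dummy_input_ids)
    print(out['logits'].shape)  # expected: torch.Size([1, 12, 17])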

# Load the pickled BERT and BiLSTM models. This assumes the serialized model files from
# the joyinning/chatbot-info-extraction repository are available at these local paths.
def load_models():
    with open('joyinning/chatbot-info-extraction/models/bert-model.pkl', 'rb') as f:
        bert_model = pickle.load(f)
    bert_model.eval()

    with open('joyinning/chatbot-info-extraction/models/bilstm-model.pkl', 'rb') as f:
        bilstm_model = pickle.load(f)
    bilstm_model.eval()

    return bert_model, bilstm_model

def load_custom_model(model_dir, tokenizer_dir, id2label):
    """Load the saved BiLSTM NER weights and tokenizer from local directories."""
    config = AutoConfig.from_pretrained(model_dir, local_files_only=True)
    config.id2label = id2label
    config.num_labels = len(id2label)

    model = BiLSTMForTokenClassification(model_name=config._name_or_path, num_labels=config.num_labels)
    model.config.id2label = id2label
    model.load_state_dict(torch.load(os.path.join(model_dir, 'pytorch_model.bin'), map_location=torch.device('cpu')))
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, local_files_only=True)

    return model, tokenizer

ner_model_dir = "models/bilstm_ner"
tokenizer_dir = "models/tokenizer"
id2label_ner = {0: 'O', 1: 'I-art', 2: 'B-org', 3: 'B-geo', 4: 'I-per', 5: 'B-eve', 6: 'I-geo', 7: 'B-per', 8: 'I-nat', 9: 'B-art', 10: 'B-tim', 11: 'I-gpe', 12: 'I-tim', 13: 'B-nat', 14: 'B-gpe', 15: 'I-org', 16: 'I-eve'}
ner_model, ner_tokenizer = load_custom_model(ner_model_dir, tokenizer_dir, id2label_ner)

# QA model
qa_model = pipeline('question-answering', model='deepset/bert-base-cased-squad2')

# Function to extract information
def extract_information(text, bert_model, bilstm_model, ner_tokenizer, id2label_ner):
    """Run NER tagging, 4W QA, and Why/How question generation on a sentence.
    (bert_model is accepted for interface compatibility but is not used here.)"""
    extracted_info = {}

    ner_tags = predict_tags(text, bilstm_model, ner_tokenizer, id2label_ner)
    
    extracted_info.update(extract_4w_qa(text, ner_tags))
    
    qa_result = generate_why_or_how_question_and_answer(extracted_info, text)
    if qa_result:
        extracted_info.update(qa_result)
        prompt = f"Question: {qa_result['question']}\nContext: {text}\nAnswer:"
        extracted_info["Token Count"] = count_tokens(prompt)
        
    return extracted_info


def predict_tags(sentence, model, tokenizer, label_map):
    """
    Predicts NER tags for a given sentence using the specified model and tokenizer.

    Args:
        sentence: The input sentence as a string.
        model: The pre-trained model (BiLSTM) for tag prediction.
        tokenizer: The tokenizer used for converting the sentence into tokens.
        label_map: A dictionary mapping numerical label indices to their corresponding tags.

    Returns:
        A list of predicted tags for each token in the sentence.
    """
    inputs = tokenizer.encode(sentence, return_tensors='pt')

    # Run inference without tracking gradients.
    with torch.no_grad():
        outputs = model(inputs)
    logits = outputs['logits']
    predictions = torch.argmax(logits, dim=2)

    # Drop the [CLS] and [SEP] positions; map any unknown index to "O".
    labels = [label_map.get(prediction.item(), "O") for prediction in predictions[0][1:-1]]
    return labels

def extract_4w_qa(sentence, ner_tags):
    """
    Extracts 4w (Who, What, When, Where) information from a sentence
    using NER tags and a question-answering model.

    Args:
        sentence: The input sentence as a string.
        ner_tags: A list of predicted NER tags for each token in the sentence.

    Returns:
        A dictionary where keys are 4W question words and values are the corresponding
        answers extracted from the sentence.
    """
    result = {}
    questions = {
        "B-per": "Who",
        "I-per": "Who",
        "B-geo": "Where",
        "I-geo": "Where",
        "B-org": "What organization",
        "I-org": "What organization",
        "B-tim": "When",
        "I-tim": "When",
        "B-art": "What art",
        "I-art": "What art",
        "B-eve": "What event",
        "I-eve": "What event",
        "B-nat": "What natural phenomenon",
        "I-nat": "What natural phenomenon",
    }

    # NER tags are predicted per wordpiece token, so this assumes they line up with the
    # whitespace-split words of the sentence.
    for ner_tag, entity in zip(ner_tags, sentence.split()):
        if ner_tag in questions:
            question = f"{questions[ner_tag]} is {entity}?"
            answer = qa_model(question=question, context=sentence)["answer"]
            result[questions[ner_tag]] = answer

    return result
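
# Illustrative sketch only (hypothetical sentence and hand-written tags, not part of the
# original pipeline): with per/geo/tim tags aligned to the words below, extract_4w_qa asks
# the QA model one question per tagged word and returns answers keyed by the question word.
def _demo_extract_4w_qa():
    sentence = "Angela visited Berlin on Monday"
    tags = ["B-per", "O", "B-geo", "O", "B-tim"]
    print(extract_4w_qa(sentence, tags))  # e.g. {'Who': ..., 'Where': ..., 'When': ...}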

def count_tokens(text):
    """
    Counts the number of tokens in a text string using the tiktoken encoding for GPT-3.5 Turbo.

    Args:
        text: The input text string.

    Returns:
        The number of tokens in the text.
    """
    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
    return len(encoding.encode(text))

def generate_why_or_how_question_and_answer(extracted_info, sentence):
    """
    Generates a "Why" or "How" question based on the extracted 4W information and gets the answer using GPT-3.5.

    Args:
        extracted_info: A dictionary containing the extracted 4W information.
        sentence: The original sentence.

    Returns:
        A dictionary containing the generated question and its answer, or None if no relevant question can be generated.
    """

    prompt_template = """
    Given the following extracted information and the original sentence, generate a relevant "Why" or "How" question and provide a concise answer based on the given context.

    Extracted Information: {extracted_info}
    Sentence: {sentence}

    Question and Answer:
    """

    prompt = prompt_template.format(extracted_info=extracted_info, sentence=sentence)
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
        max_tokens=150,
        stop=None,
        temperature=0.5,
    )

    question_and_answer = response.choices[0].message.content.strip()

    if question_and_answer:
        try:
            question, answer = question_and_answer.split("\n", 1)
            return {"question": question, "answer": answer}
        except ValueError:
            return None
    else:
        return None

def get_why_or_how_answer(question, context):
    """
    Queries OpenAI's GPT-3.5 model to generate an answer for a given question based on the provided context.

    Args:
        question (str): The question to be answered.
        context (str): The text context from which the answer should be extracted.

    Returns:
        str: The generated answer from GPT-3.5.
    """
    prompt = f"Question: {question}\nContext: {context}\nAnswer:"

    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
        max_tokens=150,
        stop=None,
        temperature=0.5,
    )

    return response.choices[0].message.content.strip()
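
# A minimal end-to-end usage sketch, assuming the model/tokenizer directories above exist
# locally and OPENAI_API_KEY is set; the sample sentence is hypothetical. Since
# extract_information does not use its bert_model argument, None is passed for it here.
if __name__ == "__main__":
    sample = "Barack Obama visited Paris in 2015 to attend the climate summit."
    info = extract_information(sample, None, ner_model, ner_tokenizer, id2label_ner)
    print(info)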