import os
import pickle  # for saving and loading Python objects

import tiktoken
import torch
import torch.nn as nn
from openai import OpenAI
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    BertForTokenClassification,
    pipeline,
)

# Read the API key from the environment instead of hard-coding a secret in source.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))


# Define BiLSTMForTokenClassification class
class BiLSTMForTokenClassification(nn.Module):
    """
    Combines BERT embeddings with a Bidirectional LSTM (BiLSTM) for token-level
    classification tasks such as Named Entity Recognition (NER).

    Args:
        model_name: Name of the pre-trained BERT model to use (e.g., "bert-base-cased").
        num_labels: Number of distinct labels to predict.
        hidden_size: Dimension of the hidden states in the BiLSTM (default: 128).
        num_lstm_layers: Number of stacked BiLSTM layers (default: 1).
    """

    def __init__(self, model_name, num_labels, hidden_size=128, num_lstm_layers=1):
        super().__init__()
        self.num_labels = num_labels
        self.config = AutoConfig.from_pretrained(model_name)
        self.bert = AutoModel.from_pretrained(model_name)

        # Freeze the BERT embedding layer so only the encoder, BiLSTM, and classifier are trained.
        for name, param in self.bert.named_parameters():
            if name.startswith("embeddings"):
                param.requires_grad = False

        self.bilstm = nn.LSTM(
            self.bert.config.hidden_size,
            hidden_size,
            num_layers=num_lstm_layers,
            bidirectional=True,
            batch_first=True,
        )
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_size * 2, num_labels)

    def forward(self, input_ids, attention_mask=None, labels=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]
        lstm_output, _ = self.bilstm(sequence_output)
        lstm_output = self.dropout(lstm_output)
        logits = self.classifier(lstm_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Only score positions covered by the attention mask and with valid label ids.
            active_loss = attention_mask.view(-1) == 1
            active_logits = logits.view(-1, self.num_labels)
            active_labels = torch.where(
                active_loss,
                labels.view(-1),
                torch.tensor(loss_fct.ignore_index).type_as(labels),
            )
            valid_mask = (active_labels >= 0) & (active_labels < self.num_labels)
            active_logits = active_logits[valid_mask]
            active_labels = active_labels[valid_mask]
            loss = loss_fct(active_logits, active_labels)

        return {'loss': loss, 'logits': logits}


# Load custom BiLSTM and pre-trained BERT
def load_models():
    # NOTE: from_pretrained expects a model directory or Hub repo ID (containing config.json
    # and weights), not a .pkl file; point it at the directory that holds the saved model.
    bert_model = BertForTokenClassification.from_pretrained(
        "joyinning/chatbot-info-extraction/models/bert-model.pkl"
    )
    bert_model.eval()
    with open('joyinning/chatbot-info-extraction/models/bilstm-model.pkl', 'rb') as f:
        bilstm_model = pickle.load(f)
    return bert_model, bilstm_model


def load_custom_model(model_dir, tokenizer_dir, id2label):
    config = AutoConfig.from_pretrained(model_dir, local_files_only=True)
    config.id2label = id2label
    config.num_labels = len(id2label)

    model = BiLSTMForTokenClassification(
        model_name=config._name_or_path, num_labels=config.num_labels
    )
    model.config.id2label = id2label
    model.load_state_dict(
        torch.load(os.path.join(model_dir, 'pytorch_model.bin'), map_location=torch.device('cpu'))
    )
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, local_files_only=True)
    return model, tokenizer


ner_model_dir = "models/bilstm_ner"
tokenizer_dir = "models/tokenizer"
id2label_ner = {
    0: 'O', 1: 'I-art', 2: 'B-org', 3: 'B-geo', 4: 'I-per', 5: 'B-eve',
    6: 'I-geo', 7: 'B-per', 8: 'I-nat', 9: 'B-art', 10: 'B-tim', 11: 'I-gpe',
    12: 'I-tim', 13: 'B-nat', 14: 'B-gpe', 15: 'I-org', 16: 'I-eve',
}
ner_model, ner_tokenizer = load_custom_model(ner_model_dir, tokenizer_dir, id2label_ner)

# QA model
qa_model = pipeline('question-answering', model='deepset/bert-base-cased-squad2')


# Function to extract information
def extract_information(text, bert_model, bilstm_model, ner_tokenizer, id2label_ner):
    extracted_info = {}
    ner_tags = predict_tags(text, bilstm_model, ner_tokenizer, id2label_ner)
    extracted_info.update(extract_4w_qa(text, ner_tags))

    qa_result = generate_why_or_how_question_and_answer(extracted_info, text)
    if qa_result:
        extracted_info.update(qa_result)
        prompt = f"Question: {qa_result['question']}\nContext: {text}\nAnswer:"
        extracted_info["Token Count"] = count_tokens(prompt)

    return extracted_info


def predict_tags(sentence, model, tokenizer, label_map):
    """
    Predicts NER tags for a given sentence using the specified model and tokenizer.

    Args:
        sentence: The input sentence as a string.
        model: The pre-trained model (BiLSTM) for tag prediction.
        tokenizer: The tokenizer used for converting the sentence into tokens.
        label_map: A dictionary mapping numerical label indices to their corresponding tags.

    Returns:
        A list of predicted tags for each token in the sentence.
    """
    inputs = tokenizer.encode(sentence, return_tensors='pt')
    with torch.no_grad():
        outputs = model(inputs)
    logits = outputs['logits']
    predictions = torch.argmax(logits, dim=2)
    # Drop the [CLS] and [SEP] positions, then map label indices to tag strings.
    labels = [label_map.get(prediction.item(), "O") for prediction in predictions[0][1:-1]]
    return labels


def extract_4w_qa(sentence, ner_tags):
    """
    Extracts 4W (Who, What, When, Where) information from a sentence using NER tags
    and a question-answering model.

    Args:
        sentence: The input sentence as a string.
        ner_tags: A list of predicted NER tags for each token in the sentence.

    Returns:
        A dictionary where keys are 4W question words and values are the corresponding
        answers extracted from the sentence.
    """
    result = {}
    questions = {
        "B-per": "Who", "I-per": "Who",
        "B-geo": "Where", "I-geo": "Where",
        "B-org": "What organization", "I-org": "What organization",
        "B-tim": "When", "I-tim": "When",
        "B-art": "What art", "I-art": "What art",
        "B-eve": "What event", "I-eve": "What event",
        "B-nat": "What natural phenomenon", "I-nat": "What natural phenomenon",
    }

    # Assumes one tag per whitespace-split word; wordpiece tokenization can shift this
    # alignment for words that the tokenizer splits into several sub-tokens.
    for ner_tag, entity in zip(ner_tags, sentence.split()):
        if ner_tag in questions:
            question = f"{questions[ner_tag]} is {entity}?"
            answer = qa_model(question=question, context=sentence)["answer"]
            result[questions[ner_tag]] = answer
    return result


def count_tokens(text):
    """
    Counts the number of tokens in a text string using the tiktoken encoding
    for GPT-3.5 Turbo.

    Args:
        text: The input text string.

    Returns:
        The number of tokens in the text.
    """
    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo-instruct")
    return len(encoding.encode(text))


def generate_why_or_how_question_and_answer(extracted_info, sentence):
    """
    Generates a "Why" or "How" question based on the extracted 4W information
    and gets the answer using GPT-3.5.

    Args:
        extracted_info: A dictionary containing the extracted 4W information.
        sentence: The original sentence.

    Returns:
        A dictionary containing the generated question and its answer, or None
        if no relevant question can be generated.
    """
    prompt_template = """
    Given the following extracted information and the original sentence, generate a relevant "Why" or "How" question and provide a concise answer based on the given context.
    Extracted Information: {extracted_info}
    Sentence: {sentence}

    Question and Answer:
    """
    prompt = prompt_template.format(extracted_info=extracted_info, sentence=sentence)

    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
        max_tokens=150,
        stop=None,
        temperature=0.5,
    )

    question_and_answer = response.choices[0].message.content.strip()
    if question_and_answer:
        try:
            # Expect the question on the first line and the answer on the lines after it.
            question, answer = question_and_answer.split("\n", 1)
            return {"question": question, "answer": answer}
        except ValueError:
            return None
    else:
        return None


def get_why_or_how_answer(question, context):
    """
    Queries OpenAI's GPT-3.5 model to generate an answer for a given question
    based on the provided context.

    Args:
        question (str): The question to be answered.
        context (str): The text context from which the answer should be extracted.

    Returns:
        str: The generated answer from GPT-3.5.
    """
    prompt = f"Question: {question}\nContext: {context}\nAnswer:"
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
        max_tokens=150,
        stop=None,
        temperature=0.5,
    )
    # Chat completions expose the text under message.content, not a top-level .text field.
    return response.choices[0].message.content.strip()
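
# Usage sketch (illustrative, not part of the original pipeline): runs the full extraction
# flow on one sentence. It assumes the saved NER model and tokenizer exist under models/
# and that OPENAI_API_KEY is set in the environment. The sample sentence is hypothetical,
# and None is passed for bert_model because extract_information does not use that argument.
if __name__ == "__main__":
    sample_text = "Barack Obama visited Paris in August 2015 to meet with company officials."
    info = extract_information(sample_text, None, ner_model, ner_tokenizer, id2label_ner)
    for key, value in info.items():
        print(f"{key}: {value}")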