Spaces:
Sleeping
Sleeping
import torch | |
from transformers import BertTokenizer, BertForTokenClassification, pipeline | |
import pickle # for saving and loading Python objects | |
from openai import OpenAI | |
import tiktoken | |
from transformers import AutoConfig, AutoTokenizer | |
import os | |
import torch.nn as nn | |
from transformers import AutoModel, AutoConfig | |
# OpenAI client shared by the question-generation helpers below.
# The API key is read from the OPENAI_API_KEY environment variable instead of
# being hard-coded: a key committed to source is exposed to anyone who can read it.
# NOTE(review): the key that was previously hard-coded here should be treated as
# compromised and revoked.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
# Define BiLSTMForTokenClassification Class | |
class BiLSTMForTokenClassification(nn.Module):
    """
    Token-classification model that runs a Bidirectional LSTM (BiLSTM) over the
    token-level hidden states of a pre-trained transformer encoder (e.g. BERT)
    and projects every position onto the label set, for tasks like Named Entity
    Recognition (NER).

    The encoder's embedding parameters are frozen; the remaining encoder layers,
    the BiLSTM and the classifier stay trainable.

    Args:
        model_name: Name or path of the pre-trained encoder, passed to
            ``AutoConfig.from_pretrained`` and ``AutoModel.from_pretrained``
            (e.g. "bert-base-cased").
        num_labels: Number of different labels to predict.
        hidden_size: Hidden size of each LSTM direction (default: 128); the
            classifier therefore receives ``2 * hidden_size`` features.
        num_lstm_layers: Number of stacked BiLSTM layers (default: 1).
    """

    def __init__(self, model_name, num_labels, hidden_size=128, num_lstm_layers=1):
        super().__init__()
        self.num_labels = num_labels
        # Exposed as ``self.config`` so callers can attach metadata (e.g. id2label).
        self.config = AutoConfig.from_pretrained(model_name)
        self.bert = AutoModel.from_pretrained(model_name)
        # Freeze the encoder's embedding layer; deeper layers remain trainable.
        for name, param in self.bert.named_parameters():
            if name.startswith("embeddings"):
                param.requires_grad = False
        self.bilstm = nn.LSTM(
            self.bert.config.hidden_size,
            hidden_size,
            num_layers=num_lstm_layers,
            bidirectional=True,
            batch_first=True,
        )
        self.dropout = nn.Dropout(0.1)
        # Forward and backward LSTM states are concatenated, hence hidden_size * 2.
        self.classifier = nn.Linear(hidden_size * 2, num_labels)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """
        Compute per-token logits and, when labels are given, the training loss.

        Args:
            input_ids: Token ids, batch-first (batch, seq_len).
            attention_mask: 1 for positions that count toward the loss, anything
                else for padding; defaults to all ones.
            labels: Optional gold label ids with the same shape as ``input_ids``.
                Positions whose label is outside ``[0, num_labels)`` (e.g. -100)
                or whose attention mask is not 1 are excluded from the loss.

        Returns:
            dict with ``'logits'`` of shape (batch, seq_len, num_labels) and
            ``'loss'``: a scalar tensor, or None when ``labels`` is None. If no
            position is eligible for the loss, the loss is 0 (not NaN).
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]  # token-level hidden states of the encoder
        # NOTE: sequences are not packed, so the BiLSTM also reads padded
        # positions; padding is only excluded from the loss below.
        lstm_output, _ = self.bilstm(sequence_output)
        lstm_output = self.dropout(lstm_output)
        logits = self.classifier(lstm_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            active_loss = attention_mask.view(-1) == 1
            active_logits = logits.view(-1, self.num_labels)
            flat_labels = labels.view(-1)
            # Masked-out positions get ignore_index (negative), so the range
            # filter below drops them together with any other invalid label.
            active_labels = torch.where(
                active_loss,
                flat_labels,
                torch.full_like(flat_labels, loss_fct.ignore_index),
            )
            valid_mask = (active_labels >= 0) & (active_labels < self.num_labels)
            active_logits = active_logits[valid_mask]
            active_labels = active_labels[valid_mask]
            if active_labels.numel() == 0:
                # Mean-reduced CrossEntropyLoss over zero elements is NaN; use a
                # zero loss that stays attached to the graph instead.
                loss = logits.sum() * 0.0
            else:
                loss = loss_fct(active_logits, active_labels)
        return {'loss': loss, 'logits': logits}
# Load custom BiLSTM and pre-trained BERT | |
def load_models(
    bert_path="joyinning/chatbot-info-extraction/models/bert-model.pkl",
    bilstm_path='joyinning/chatbot-info-extraction/models/bilstm-model.pkl',
):
    """
    Load the BERT token-classification model and the pickled BiLSTM model.

    Args:
        bert_path: Passed to ``BertForTokenClassification.from_pretrained``.
        bilstm_path: Filesystem path of the pickled BiLSTM model.

    Returns:
        Tuple ``(bert_model, bilstm_model)``; the BERT model is switched to
        eval mode.

    NOTE(review): ``from_pretrained`` expects a Hub repo id or a local
    directory, not a ``.pkl`` file; confirm the default ``bert_path`` actually
    resolves. Only load ``bilstm_path`` from a trusted source: ``pickle.load``
    can execute arbitrary code embedded in the file.
    """
    bert_model = BertForTokenClassification.from_pretrained(bert_path)
    bert_model.eval()
    with open(bilstm_path, 'rb') as f:
        bilstm_model = pickle.load(f)  # trusted files only (see docstring)
    return bert_model, bilstm_model
def load_custom_model(model_dir, tokenizer_dir, id2label):
    """
    Build a BiLSTMForTokenClassification from a local checkpoint directory.

    Args:
        model_dir: Local directory containing the model config and the
            ``pytorch_model.bin`` state dict.
        tokenizer_dir: Local directory containing the tokenizer files.
        id2label: Mapping from label index to tag string; its length sets the
            number of output labels.

    Returns:
        Tuple ``(model, tokenizer)``. The weights are loaded onto the CPU and
        the model is in eval mode, ready for inference.
    """
    config = AutoConfig.from_pretrained(model_dir, local_files_only=True)
    config.id2label = id2label
    config.num_labels = len(id2label)
    # The encoder is re-created from the name/path recorded in the saved config.
    # NOTE(review): depending on what ``_name_or_path`` holds, this may fetch the
    # base encoder weights before the state dict below overwrites them.
    model = BiLSTMForTokenClassification(model_name=config._name_or_path, num_labels=config.num_labels)
    model.config.id2label = id2label
    # A state dict only needs tensors/containers, so restrict unpickling to those.
    state_dict = torch.load(
        os.path.join(model_dir, 'pytorch_model.bin'),
        map_location=torch.device('cpu'),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    # nn.Module starts in training mode; switch to eval so the dropout layer is
    # disabled and predictions are deterministic.
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, local_files_only=True)
    return model, tokenizer
ner_model_dir = "models/bilstm_ner"
tokenizer_dir = "models/tokenizer"

# Label index -> NER tag. A tag's position in this list is the class id the
# NER head predicts, so the order must not change.
id2label_ner = dict(enumerate([
    'O',      # 0
    'I-art',  # 1
    'B-org',  # 2
    'B-geo',  # 3
    'I-per',  # 4
    'B-eve',  # 5
    'I-geo',  # 6
    'B-per',  # 7
    'I-nat',  # 8
    'B-art',  # 9
    'B-tim',  # 10
    'I-gpe',  # 11
    'I-tim',  # 12
    'B-nat',  # 13
    'B-gpe',  # 14
    'I-org',  # 15
    'I-eve',  # 16
]))

# Load the BiLSTM NER model and its tokenizer from the local directories above.
ner_model, ner_tokenizer = load_custom_model(ner_model_dir, tokenizer_dir, id2label_ner)

# QA model: extractive question-answering pipeline queried by extract_4w_qa.
qa_model = pipeline('question-answering', model='deepset/bert-base-cased-squad2')
# Function to extract information | |
def extract_information(text, bert_model, bilstm_model, ner_tokenizer, id2label_ner):
    """
    Run the full extraction pipeline on *text*.

    The pipeline has three steps:
    1. Predict NER tags with *bilstm_model*.
    2. Ask the QA model one question per tagged entity (``extract_4w_qa``).
    3. Ask the chat model for an extra "Why"/"How" question and answer.

    Args:
        text: Input sentence.
        bert_model: Unused; kept so the signature stays backward compatible.
        bilstm_model: Token-classification model used for NER tagging.
        ner_tokenizer: Tokenizer matching *bilstm_model*.
        id2label_ner: Mapping from label index to NER tag.

    Returns:
        Dict of extracted fields. When a Why/How question was generated, the
        dict also holds that result's keys plus "Token Count", the token length
        of a "Question/Context/Answer" prompt built from the question.
    """
    extracted_info = {}
    ner_tags = predict_tags(text, bilstm_model, ner_tokenizer, id2label_ner)
    extracted_info.update(extract_4w_qa(text, ner_tags))
    qa_result = generate_why_or_how_question_and_answer(extracted_info, text)
    if qa_result:
        extracted_info.update(qa_result)
        # The prompt is derived from the generated question, so it can only be
        # counted when one exists. Counting it unconditionally raised
        # UnboundLocalError whenever qa_result was None.
        prompt = f"Question: {qa_result['question']}\nContext: {text}\nAnswer:"
        extracted_info["Token Count"] = count_tokens(prompt)
    return extracted_info
def predict_tags(sentence, model, tokenizer, label_map):
    """
    Predict NER tags for a given sentence using the specified model and tokenizer.

    Args:
        sentence: The input sentence as a string.
        model: Callable that takes the encoded ids and returns a dict whose
            ``'logits'`` entry is indexed as (batch, seq_len, num_labels),
            e.g. BiLSTMForTokenClassification.
        tokenizer: Tokenizer used to encode the sentence. The ``[1:-1]`` slice
            below assumes ``encode`` adds exactly one special token at each end
            ([CLS]/[SEP] for BERT tokenizers).
        label_map: Mapping from numerical label index to tag; indices missing
            from the map become "O".

    Returns:
        A list of predicted tags, one per tokenizer (sub-word) token with the
        special tokens removed. This is not necessarily one tag per
        whitespace-separated word.
    """
    input_ids = tokenizer.encode(sentence, return_tensors='pt')
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(input_ids)['logits']
    predicted_ids = torch.argmax(logits, dim=2)
    # Drop the first and last positions (special tokens added by encode).
    return [label_map.get(idx.item(), "O") for idx in predicted_ids[0][1:-1]]
def extract_4w_qa(sentence, ner_tags):
    """
    Extract Who/Where/When/What information from a sentence using NER tags and
    the module-level question-answering pipeline ``qa_model``.

    For each word whose tag is a person, location, organization, time,
    artifact, event or natural-phenomenon tag, a question such as
    "Who is <word>?" is asked against the sentence.

    Args:
        sentence: The input sentence as a string.
        ner_tags: Predicted NER tags, paired positionally with
            ``sentence.split()``.

    Returns:
        A dictionary mapping a question phrase ("Who", "Where", "When",
        "What organization", "What art", "What event",
        "What natural phenomenon") to the answer extracted from the sentence.
        When several words map to the same phrase, the answer for the last one
        wins.

    NOTE(review): tags from ``predict_tags`` are presumably per sub-word token,
    while this function pairs them with whitespace-separated words. Tags and
    words can therefore drift apart for words split into several tokens, and
    ``zip`` silently truncates to the shorter sequence.
    """
    result = {}
    questions = {
        "B-per": "Who",
        "I-per": "Who",
        "B-geo": "Where",
        "I-geo": "Where",
        "B-org": "What organization",
        "I-org": "What organization",
        "B-tim": "When",
        "I-tim": "When",
        "B-art": "What art",
        "I-art": "What art",
        "B-eve": "What event",
        "I-eve": "What event",
        "B-nat": "What natural phenomenon",
        "I-nat": "What natural phenomenon",
    }
    # Cache answers per question text so a repeated entity word does not trigger
    # a second identical QA call. This assumes the pipeline returns the same
    # answer for identical inputs, as extractive QA inference normally does.
    answer_cache = {}
    for ner_tag, entity in zip(ner_tags, sentence.split()):
        question_word = questions.get(ner_tag)
        if question_word is None:
            continue
        question = f"{question_word} is {entity}?"
        if question not in answer_cache:
            answer_cache[question] = qa_model(question=question, context=sentence)["answer"]
        result[question_word] = answer_cache[question]
    return result
def count_tokens(text):
    """
    Return how many tokens *text* encodes to.

    Tokenization uses tiktoken's encoding for the "gpt-3.5-turbo-instruct"
    model.

    Args:
        text: The string to measure.

    Returns:
        The token count as an int.
    """
    token_ids = tiktoken.encoding_for_model("gpt-3.5-turbo-instruct").encode(text)
    return len(token_ids)
def generate_why_or_how_question_and_answer(extracted_info, sentence):
    """
    Ask the OpenAI chat model (gpt-3.5-turbo) for a "Why" or "How" question
    about *sentence*, plus a concise answer, given the already extracted info.

    Args:
        extracted_info: Dictionary of previously extracted fields; it is
            inserted into the prompt via ``str.format``.
        sentence: The original sentence.

    Returns:
        ``{"question": ..., "answer": ...}``. The question is the first line of
        the model's reply and the answer is the remainder; both have
        surrounding whitespace stripped. Returns None if the reply is empty or
        has only one line.
    """
    prompt_template = """
Given the following extracted information and the original sentence, generate a relevant "Why" or "How" question and provide a concise answer based on the given context.
Extracted Information: {extracted_info}
Sentence: {sentence}
Question and Answer:
"""
    prompt = prompt_template.format(extracted_info=extracted_info, sentence=sentence)
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
        max_tokens=150,
        stop=None,
        temperature=0.5,
    )
    # ``message.content`` is Optional in the SDK's types; calling .strip() on
    # None would raise, so treat a missing reply as empty.
    question_and_answer = (response.choices[0].message.content or "").strip()
    # Split at the first newline; no separator means no usable Q/A pair.
    question, separator, answer = question_and_answer.partition("\n")
    if not separator:
        return None
    return {"question": question.strip(), "answer": answer.strip()}
def get_why_or_how_answer(question, context):
    """
    Query OpenAI's chat model (gpt-3.5-turbo) to answer a question based on the
    provided context.

    Args:
        question (str): The question to be answered.
        context (str): The text context the answer should be based on.

    Returns:
        str: The model's reply with surrounding whitespace stripped, or "" if
        the reply carries no text content.
    """
    prompt = f"Question: {question}\nContext: {context}\nAnswer:"
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
        max_tokens=150,
        stop=None,
        temperature=0.5,
    )
    # Chat completions return the reply in ``message.content``. ``.text`` only
    # exists on the legacy Completions API and raised AttributeError here.
    # ``content`` is Optional, so fall back to an empty string.
    content = response.choices[0].message.content
    return (content or "").strip()