from typing import Dict, List, Tuple

import torch
from transformers import BertForTokenClassification

from .config_train import device, model_load_path, tokenizer
from .DataProcessing import read_input
from .load_data import sorted_tags


class Key_Ner_Predictor:
    """Wrap a fine-tuned BERT token-classification model for tag prediction."""

    def __init__(self, model_path: str, tokenizer, device: torch.device,
                 tag_map: Dict[int, str]) -> None:
        """
        Initialize the Key_Ner_Predictor with the model, tokenizer, and device.

        Args:
            model_path (str): Path passed to
                ``BertForTokenClassification.from_pretrained``.
            tokenizer: Tokenizer used to encode input sentences and decode
                token ids; must expose ``encode``, ``decode`` and
                ``pad_token_id``.
            device (torch.device): Device the model and inputs are moved to.
            tag_map (Dict[int, str]): Mapping of class indices to tag names.
        """
        self.model = BertForTokenClassification.from_pretrained(model_path).to(device)
        self.tokenizer = tokenizer
        self.device = device
        self.tag_map = tag_map

    def predict(self, sentence: str) -> Tuple[str, List[str]]:
        """
        Predict the set of tags present in the given sentence.

        The model labels every token; the per-token labels are then collapsed
        into a list of distinct tag names. The empty-string tag (presumably
        the "no entity" label -- confirm against ``sorted_tags``) is dropped,
        and spaces inside tag names are replaced with underscores.

        Args:
            sentence (str): Input sentence to predict.

        Returns:
            Tuple[str, List[str]]: The sentence as decoded back from its token
            ids (special tokens skipped), and the distinct non-empty tags in
            order of first appearance.
        """
        # Pre-process the raw text with the project's input reader.
        sentence = read_input(sentence)

        # Tokenize into a (1, seq_len) tensor of token ids.
        input_ids = self.tokenizer.encode(sentence, return_tensors="pt").to(self.device)

        # Attend to every position that is not the padding token.
        attention_masks = (input_ids != self.tokenizer.pad_token_id).float().to(self.device)

        # Inference only: evaluation mode and no gradient tracking.
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(input_ids, token_type_ids=None, attention_mask=attention_masks)
        logits = outputs.logits

        # Most likely class index for each token of the single input sentence;
        # tolist() yields plain Python ints suitable as tag_map keys.
        tag_indices = torch.argmax(logits, dim=2)[0].tolist()

        # De-duplicate while keeping first-appearance order so the result is
        # deterministic (a plain set would give an arbitrary order).
        unique_tags = dict.fromkeys(self.tag_map[idx] for idx in tag_indices)

        # Drop the empty tag if present. Using pop with a default avoids the
        # KeyError that set.remove('') raised when no token carried it.
        unique_tags.pop("", None)

        predicted_tags = [tag.replace(" ", "_") for tag in unique_tags]

        decoded_sentence = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return decoded_sentence, predicted_tags


# Initialize the Key_Ner_Predictor (module-level instance loaded at import time).
predictor = Key_Ner_Predictor(
    model_path=model_load_path,
    tokenizer=tokenizer,
    device=device,
    tag_map=dict(enumerate(sorted_tags)),
)

# Usage example:
# # Define the sentence to predict
# # (English gloss: "I want to go camping and watch the sunset over the sea
# # with my family")
# sentence = "Tôi muốn đi cắm trại ngắm hoàng hôn trên biển cùng gia đình"
# # Get the prediction
# original_sentence, predicted_tags = predictor.predict(sentence)
# # Print the sentence and its predicted tags
# print("Sentence:", original_sentence)
# print("Predicted Tags:", predicted_tags)