from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoModelForSequenceClassification
from transformers import pipeline
from flair.data import Sentence
from flair.models import SequenceTagger
import pickle


class Models:

    def pickle_it(self, obj, file_name):
        # Serialize an object to <file_name>.pickle in the working directory.
        with open(f'{file_name}.pickle', 'wb') as f:
            pickle.dump(obj, f)

    def unpickle_it(self, file_name):
        # Load a previously pickled object from <file_name>.pickle.
        with open(f'{file_name}.pickle', 'rb') as f:
            return pickle.load(f)

    def load_trained_models(self, pickle=False):
        # NER (dates): CamemBERT-based tagger that also recognizes date entities.
        tokenizer = AutoTokenizer.from_pretrained("Jean-Baptiste/camembert-ner-with-dates")
        model = AutoModelForTokenClassification.from_pretrained("Jean-Baptiste/camembert-ner-with-dates")
        self.ner_dates = pipeline('ner', model=model, tokenizer=tokenizer, aggregation_strategy="simple")

        # Zero-shot classification: distilled BART-MNLI (the full bart-large-mnli is kept as a commented-out alternative).
        # self.zero_shot_classifier = pipeline("zero-shot-classification", model='facebook/bart-large-mnli')
        self.zero_shot_classifier = pipeline("zero-shot-classification", model='valhalla/distilbart-mnli-12-6')

        # NER: general-purpose BERT tagger (persons, organizations, locations, misc).
        # aggregation_strategy="simple" replaces the deprecated grouped_entities=True.
        tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
        model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
        self.ner = pipeline('ner', model=model, tokenizer=tokenizer, aggregation_strategy="simple")

        # POS tagging with Flair.
        self.tagger = SequenceTagger.load("flair/pos-english-fast")

        # Note: the `pickle` argument shadows the stdlib module inside this method;
        # the module itself is only used in pickle_it/unpickle_it, so this is harmless.
        if pickle:
            self.pickle_models()

        return self.ner, self.ner_dates, self.zero_shot_classifier, self.tagger

    def pickle_models(self):
        self.pickle_it(self.ner, "ner")
        self.pickle_it(self.zero_shot_classifier, "zero_shot_classifier_6")
        self.pickle_it(self.ner_dates, "ner_dates")
        self.pickle_it(self.tagger, "pos_tagger_fast")

    def load_pickled_models(self):
        ner_dates = self.unpickle_it('ner_dates')
        ner = self.unpickle_it('ner')
        zero_shot_classifier = self.unpickle_it('zero_shot_classifier_6')
        tagger = self.unpickle_it("pos_tagger_fast")
        return ner_dates, ner, zero_shot_classifier, tagger

    def get_flair_sentence(self, sent):
        return Sentence(sent)
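

# Illustrative usage sketch (not part of the original module): loads the models once,
# then runs each pipeline on a sample sentence. The sample text and candidate labels
# below are placeholder assumptions, not values taken from the rest of the project.
if __name__ == "__main__":
    models = Models()
    ner, ner_dates, zero_shot_classifier, tagger = models.load_trained_models()

    text = "Angela Merkel visited Paris on 12 May 2019."

    print(ner(text))        # grouped person / organization / location entities
    print(ner_dates(text))  # entities including the DATE span
    print(zero_shot_classifier(text, candidate_labels=["politics", "sports", "finance"]))

    # Flair POS tagging operates on a Sentence object rather than a raw string.
    sentence = models.get_flair_sentence(text)
    tagger.predict(sentence)
    print(sentence.to_tagged_string())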