khrek's picture
Update models.py
334c45a
raw
history blame
2.07 kB
import torch
import sentencepiece
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain import PromptTemplate, LLMChain, HuggingFacePipeline
import ast
class Models():
    """Wrap the UniNER-7B model in a LangChain chain for named-entity extraction.

    Constructing an instance immediately loads the model, tokenizer and
    pipeline via ``load_trained_models``, so it is expensive.
    """

    # Hugging Face Hub id used for both the model weights and the tokenizer.
    CHECKPOINT = "Universal-NER/UniNER-7B-all"

    def __init__(self) -> None:
        # Conversation-style prompt; ``{input_text}`` and ``{entity_type}``
        # are filled in by the LangChain PromptTemplate built below.
        self.template = """
A virtual assistant answers questions from a user based on the provided text.
USER: Text: {input_text}
ASSISTANT: I’ve read this text.
USER: What describes {entity_type} in the text?
ASSISTANT:
"""
        self.load_trained_models()

    def load_trained_models(self) -> None:
        """Load the UniNER model/tokenizer and build the extraction chain.

        Sets ``self.llm``, ``self.prompt`` and ``self.llm_chain``.
        """
        # TODO(review): the original author asked whether keeping the model
        # in memory is best, or whether it should be pickled instead.
        ner_model = AutoModelForCausalLM.from_pretrained(
            self.CHECKPOINT,
            device_map="auto",
            torch_dtype=torch.float32,
            # Folder used if weights have to be offloaded to disk.
            offload_folder="offload",
            offload_state_dict=True,
        )
        # use_fast=False selects the slow tokenizer (presumably the reason
        # sentencepiece is imported at file level).
        tokenizer = AutoTokenizer.from_pretrained(
            self.CHECKPOINT, use_fast=False, padding="max_length"
        )
        hf_pipeline = pipeline(
            "text-generation",  # task
            model=ner_model,
            # Caps prompt + generated tokens combined, so long inputs leave
            # little room for the answer.
            max_length=1000,
            tokenizer=tokenizer,
            # NOTE(review): likely unnecessary for an already-instantiated
            # model object; confirm and consider dropping it.
            trust_remote_code=True,
            # Sampling makes repeated calls vary, which is what gives the
            # retries in get_ner a chance to produce a different answer.
            do_sample=True,
            top_k=10,
            num_return_sequences=1,
        )
        # NOTE(review): when a ready-made pipeline is passed in, LangChain may
        # not apply ``model_kwargs`` at generation time, so temperature=0 may
        # have no effect (it would also conflict with do_sample=True) -- confirm.
        self.llm = HuggingFacePipeline(pipeline=hf_pipeline, model_kwargs={'temperature': 0})
        self.prompt = PromptTemplate(template=self.template, input_variables=["input_text", "entity_type"])
        self.llm_chain = LLMChain(prompt=self.prompt, llm=self.llm)

    def extract_ner(self, context, entity_type):
        """Run one extraction pass and parse the model's answer.

        Args:
            context: Text in which to look for entities.
            entity_type: Entity type to ask about (e.g. "person").

        Returns:
            The Python literal parsed from the model output (expected to be
            a list of entity strings -- not validated here).

        Raises:
            ValueError, SyntaxError: the model output is not a valid Python
                literal (``ast.literal_eval`` may also raise other errors).
        """
        raw = self.llm_chain.run({"input_text": context, "entity_type": entity_type})
        # literal_eval only evaluates literals, so untrusted model output
        # cannot execute code here (unlike eval). Surrounding whitespace and
        # newlines from generation are stripped first.
        return ast.literal_eval(raw.strip())

    def get_ner(self, clean_lines, entity, *, max_tries=5):
        """Extract ``entity`` mentions from ``clean_lines``, retrying on failure.

        An attempt fails if the model returns an empty/falsy result or output
        that cannot be parsed as a Python literal.

        Args:
            clean_lines: Iterable of text lines, joined with single spaces.
            entity: Entity type to extract.
            max_tries: Maximum number of model calls before giving up.

        Returns:
            The first non-empty parsed result.

        Raises:
            ValueError: no attempt produced a non-empty result. The last parse
                error, if any, is chained as the cause.
        """
        context = ' '.join(clean_lines)
        last_error = None
        for _ in range(max_tries):
            try:
                tokens = self.extract_ner(context, entity)
            except (ValueError, SyntaxError) as err:
                # Malformed output; sampling may give a parsable one next time.
                last_error = err
                continue
            if tokens:
                return tokens
        raise ValueError(f"Couldnt extract {entity}") from last_error