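# NOTE: the following lines appear to be page-status text captured together
# with this file, not part of the program:
# Spaces:
# Runtime error
# Runtime error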
import os
from textwrap import dedent

import torch
from dotenv import load_dotenv
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# load_dotenv() | |
# login( | |
# token=os.environ["HF_TOKEN"], | |
# ) | |
# Checkpoints that Phi3InstructGraph accepts; any other model id is rejected
# with a ValueError in its constructor.
MODEL_LIST = [
    "EmergentMethods/Phi-3-mini-4k-instruct-graph",
    "EmergentMethods/Phi-3-mini-128k-instruct-graph",
    "EmergentMethods/Phi-3-medium-128k-instruct-graph",
]

# Seed torch's global RNG once, at import time.
torch.random.manual_seed(0)
class Phi3InstructGraph:
    """Extract an entity/relationship graph from text with a Phi-3 graph model.

    The constructor loads one of the checkpoints listed in ``MODEL_LIST`` and
    wraps it in a ``text-generation`` pipeline. ``extract`` builds the chat
    messages for a piece of text and returns the raw generated string.
    """

    # System prompt sent with every request. Dedented once at class creation;
    # the leading ``\n`` escape plus the physical newline give the prompt two
    # leading newlines.
    _SYSTEM_PROMPT = dedent("""\n
        A chat between a curious user and an artificial intelligence Assistant. The Assistant is an expert at identifying entities and relationships in text. The Assistant responds in JSON output only.
        The User provides text in the format:
        -------Text begin-------
        <User provided text>
        -------Text end-------
        The Assistant follows the following steps before replying to the User:
        1. **identify the most important entities** The Assistant identifies the most important entities in the text. These entities are listed in the JSON output under the key "nodes", they follow the structure of a list of dictionaries where each dict is:
        "nodes":[{"id": <entity N>, "type": <type>, "detailed_type": <detailed type>}, ...]
        where "type": <type> is a broad categorization of the entity. "detailed type": <detailed_type> is a very descriptive categorization of the entity.
        2. **determine relationships** The Assistant uses the text between -------Text begin------- and -------Text end------- to determine the relationships between the entities identified in the "nodes" list defined above. These relationships are called "edges" and they follow the structure of:
        "edges":[{"from": <entity 1>, "to": <entity 2>, "label": <relationship>}, ...]
        The <entity N> must correspond to the "id" of an entity in the "nodes" list.
        The Assistant never repeats the same node twice. The Assistant never repeats the same edge twice.
        The Assistant responds to the User in JSON only, according to the following JSON schema:
        {"type":"object","properties":{"nodes":{"type":"array","items":{"type":"object","properties":{"id":{"type":"string"},"type":{"type":"string"},"detailed_type":{"type":"string"}},"required":["id","type","detailed_type"],"additionalProperties":false}},"edges":{"type":"array","items":{"type":"object","properties":{"from":{"type":"string"},"to":{"type":"string"},"label":{"type":"string"}},"required":["from","to","label"],"additionalProperties":false}}},"required":["nodes","edges"],"additionalProperties":false}
        """)

    # User-message template. It is dedented BEFORE the text is substituted so
    # that multi-line input cannot change the common margin dedent() removes.
    _USER_TEMPLATE = dedent("""\n
        -------Text begin-------
        {text}
        -------Text end-------
        """)

    def __init__(
        self,
        model="EmergentMethods/Phi-3-mini-4k-instruct-graph",
        device=None,
    ):
        """Load the model and tokenizer and build the generation pipeline.

        Args:
            model: Model id to load; must be an entry of ``MODEL_LIST``.
            device: Value forwarded as ``device_map`` when loading the model.
                When ``None`` (the default), ``"cuda"`` is used if torch
                reports CUDA as available, otherwise ``"cpu"``.

        Raises:
            ValueError: If ``model`` is not in ``MODEL_LIST``.
        """
        if model not in MODEL_LIST:
            raise ValueError(f"model must be one of {MODEL_LIST}")

        if device is None:
            # Fall back to CPU instead of failing on hosts without a GPU.
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.model_path = model
        # NOTE(security): trust_remote_code=True executes Python code shipped
        # with the checkpoint; the MODEL_LIST check above limits which
        # repositories can be loaded this way.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            device_map=device,
            torch_dtype="auto",
            trust_remote_code=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def _generate(self, messages):
        """Run the pipeline on ``messages`` and return its raw output.

        Generation is greedy (``do_sample=False``), capped at 2000 new tokens,
        and the prompt is excluded from the returned text.
        """
        # No "temperature" entry on purpose: it is unused when sampling is
        # disabled, and passing it together with do_sample=False is flagged by
        # transformers as an inconsistent generation config.
        generation_args = {
            "max_new_tokens": 2000,
            "return_full_text": False,
            "do_sample": False,
        }
        return self.pipe(messages, **generation_args)

    def _get_messages(self, text):
        """Build the system and user chat messages for ``text``.

        Args:
            text: The text to analyse; placed between the begin/end markers.

        Returns:
            A two-element list of ``{"role", "content"}`` dicts: the system
            prompt followed by the user message wrapping ``text``.
        """
        return [
            {"role": "system", "content": self._SYSTEM_PROMPT},
            {"role": "user", "content": self._USER_TEMPLATE.format(text=text)},
        ]

    def extract(self, text):
        """Return the model's generated text for ``text``.

        The prompt asks for JSON, but the string is returned as produced by
        the pipeline; it is not parsed or validated here.
        """
        messages = self._get_messages(text)
        pipe_output = self._generate(messages)
        return pipe_output[0]["generated_text"]