"""Utilities for the Fujitsu-LLM-KG-8x7B models. | |
""" | |
from typing import Literal, Sequence, Tuple | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
############################################################################### | |
# Generation | |
############################################################################### | |
class Fujitsu_LLM_KG:
    """The Fujitsu-LLM-KG-8x7B model."""

    def __init__(self, model_id: str, *, device_map: str = "auto") -> None:
        """Initializes the model and tokenizer."""
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map=device_map,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
        )
        # Left padding so generation continues directly from the end of the prompt.
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
        # The tokenizer defines no pad token; reuse EOS for padding.
        self.tokenizer.pad_token = self.tokenizer.eos_token
    def generate(
        self,
        prompt: str,
        *,
        max_new_tokens: int = 2048,
        num_beams: int = 1,
    ) -> str:
        """Generates an answer for the given prompt."""
        tokenized = self.tokenizer(prompt, return_tensors="pt", padding=True)
        with torch.no_grad():
            outputs = self.model.generate(
                # Send inputs to the model's device rather than hardcoding "cuda",
                # so this also works with device_map placements.
                tokenized["input_ids"].to(self.model.device),
                attention_mask=tokenized["attention_mask"].to(self.model.device),
                pad_token_id=self.tokenizer.eos_token_id,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # deterministic greedy/beam decoding
                num_beams=num_beams,
            )
        # Decode the whole sequence, then drop the echoed prompt prefix.
        answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)[len(prompt):]
        return answer
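
# A minimal usage sketch (hedged): the hub id below is a placeholder rather than
# a confirmed checkpoint name; substitute the Fujitsu-LLM-KG-8x7B model id you
# actually use.
#
#     llm = Fujitsu_LLM_KG("fujitsu/Fujitsu-LLM-KG-8x7B")  # hypothetical id
#     raw = llm.generate("Which river flows through Kyoto?", max_new_tokens=256)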

###############################################################################
# Extraction
###############################################################################
def extract_turtle(text: str, *, with_rationale: bool = False) -> str:
    """Extracts the RDF Turtle part from the output text of the
    Fujitsu-LLM-KG-8x7B_inst-infer model.
    """
    # Lines belonging to the Turtle block start with one of these tokens.
    TOKENS = ["<", "rel:", "rdf:", "]"]
    if with_rationale:
        TOKENS.append("#@")  # "#@" prefixes the model's rationale comments
    turtle = ""
    for line in text.splitlines():
        line_ = line.strip()
        # Keep blank lines and lines that open with a Turtle token.
        if line_ == "" or any(line_.startswith(c) for c in TOKENS):
            if turtle:
                turtle += "\n"
            turtle += line
    return turtle
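
# A hedged example of extract_turtle (the text is illustrative, not real model
# output):
#
#     sample = "Preamble\n<ex:Alice> rel:knows <ex:Bob> .\n#@ because they met"
#     extract_turtle(sample)                       # -> "<ex:Alice> rel:knows <ex:Bob> ."
#     extract_turtle(sample, with_rationale=True)  # also keeps the "#@ ..." line
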
def extract_answer(text: str) -> Tuple[str, Sequence[str]]:
    """Extracts the final answer part from the output text of the
    Fujitsu-LLM-KG-8x7B_inst-infer model.
    """
    path = []
    answer = ""
    # State machine: collect path lines after "## Explore Path" and answer lines
    # after "## Answer"; a closing code fence after the answer ends parsing.
    state: Literal["path", "answer"] = "path"
    for line in text.splitlines():
        # Content lines (neither headings nor code fences) go to the current bucket.
        if line.strip() and "```" not in line and "## " not in line:
            if state == "path":
                path.append(line)
            elif state == "answer":
                if answer:
                    answer += "\n"
                answer += line
        if "## Explore Path" in line:
            state = "path"
            path = []  # restart so only the last explored path is kept
        elif "## Answer" in line:
            state = "answer"
            answer = ""
        elif "```" in line and answer:
            break
    path = tuple(p.strip() for p in path)
    answer = answer.strip()
    return answer, path
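
# A hedged, self-contained demo of extract_answer (the sample text is
# illustrative, but follows the "## Explore Path" / "## Answer" layout the
# parser expects); run this module directly to try it.
if __name__ == "__main__":
    sample = (
        "## Explore Path\n"
        "Q -> hasAuthor -> A\n"
        "## Answer\n"
        "```\n"
        "William Shakespeare\n"
        "```\n"
    )
    answer, path = extract_answer(sample)
    print("answer:", answer)  # -> William Shakespeare
    print("path:", path)      # -> ('Q -> hasAuthor -> A',)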