"""Utilities for the Fujitsu-LLM-KG-8x7B models.
"""
from typing import Literal, Sequence, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

###############################################################################
# Generation
###############################################################################
class Fujitsu_LLM_KG:
"""The Fujitsu-LLM-KG-8x7B model.
"""

    def __init__(self, model_id: str, *, device_map: str = "auto") -> None:
        """Initializes the model and tokenizer.
        """
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map=device_map,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
        # The tokenizer may not define a pad token; reuse EOS for padding.
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def generate(
        self,
        prompt: str,
        *,
        max_new_tokens: int = 2048,
        num_beams: int = 1,
    ) -> str:
        """Generates an answer.
        """
        tokenized = self.tokenizer(prompt, return_tensors="pt", padding=True)
        with torch.no_grad():
            outputs = self.model.generate(
                # Send inputs to the model's device rather than assuming "cuda",
                # since device_map may place the model elsewhere.
                tokenized["input_ids"].to(self.model.device),
                attention_mask=tokenized["attention_mask"].to(self.model.device),
                pad_token_id=self.tokenizer.eos_token_id,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                num_beams=num_beams,
            )
        # Drop the echoed prompt so only the newly generated text is returned.
        answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)[len(prompt):]
        return answer

###############################################################################
# Extraction
###############################################################################
def extract_turtle(text: str, *, with_rationale: bool = False) -> str:
    """Extracts the RDF Turtle part from the output text of the
    Fujitsu-LLM-KG-8x7B_inst-infer model.
"""
    # Lines beginning with one of these tokens are kept as Turtle.
    TOKENS = ["<", "rel:", "rdf:", "]"]
    if with_rationale:
        # Also keep "#@" rationale lines.
        TOKENS.append("#@")
    turtle = ""
    for line in text.splitlines():
        stripped = line.strip()
        if line == "" or any(stripped.startswith(token) for token in TOKENS):
            if turtle:
                turtle += "\n"
            turtle += line
    return turtle


def extract_answer(text: str) -> Tuple[str, Sequence[str]]:
    """Extracts the final answer part from the output text of the
    Fujitsu-LLM-KG-8x7B_inst-infer model.
"""
    path = []
    answer = ""
    # The model explores a path first and then states the answer; track which
    # section the current line belongs to.
    state: Literal["path", "answer"] = "path"
    for line in text.splitlines():
        # Collect content lines; section headers and code fences are skipped.
        if line.strip() and "```" not in line and "## " not in line:
            if state == "path":
                path.append(line)
            elif state == "answer":
                if answer:
                    answer += "\n"
                answer += line
        # Section headers switch the state and reset that section's content.
        if "## Explore Path" in line:
            state = "path"
            path = []
        elif "## Answer" in line:
            state = "answer"
            answer = ""
        elif "```" in line and answer:
            # A code fence after the answer marks the end of the output.
            break
    path = tuple(p.strip() for p in path)
    answer = answer.strip()
    return answer, path
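

# End-to-end sketch (the model ID and prompt below are illustrative
# placeholders, not verified values; generation requires a GPU):
#
#     llm = Fujitsu_LLM_KG("Fujitsu-LLM/Fujitsu-LLM-KG-8x7B_inst-infer")
#     output = llm.generate("...")
#     answer, path = extract_answer(output)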