"""Utilities for the Fujitsu-LLM-KG-8x7B models."""

from typing import List, Literal, Sequence, Tuple

import torch

try:
    from transformers import AutoModelForCausalLM, AutoTokenizer
except ImportError:  # pragma: no cover - depends on the environment
    # The extraction helpers below are pure string processing; keep them
    # importable (e.g. for post-processing saved outputs) without
    # transformers. ``Fujitsu_LLM_KG`` raises a clear error on construction.
    AutoModelForCausalLM = None
    AutoTokenizer = None


###############################################################################
# Generation
###############################################################################


class Fujitsu_LLM_KG:
    """The Fujitsu-LLM-KG-8x7B model.

    Holds a causal-LM model (loaded in bfloat16) and its left-padding
    tokenizer, and exposes deterministic text generation.
    """

    def __init__(self, model_id: str, *, device_map: str = "auto") -> None:
        """Initialize the model and tokenizer.

        Args:
            model_id: Model name or path passed to ``from_pretrained``.
            device_map: Device placement passed to
                ``AutoModelForCausalLM.from_pretrained``.

        Raises:
            ImportError: If the ``transformers`` package is not installed.
        """
        if AutoModelForCausalLM is None or AutoTokenizer is None:
            raise ImportError(
                "Fujitsu_LLM_KG requires the 'transformers' package."
            )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map=device_map,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id, padding_side="left"
        )
        # Reuse EOS as the padding token (presumably because the tokenizer
        # ships without a dedicated pad token -- confirm for new checkpoints).
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def generate(
        self,
        prompt: str,
        *,
        max_new_tokens: int = 2048,
        num_beams: int = 1,
    ) -> str:
        """Generate an answer for ``prompt`` without sampling.

        Args:
            prompt: Input text.
            max_new_tokens: Maximum number of tokens to generate.
            num_beams: Number of beams for beam search (1 = greedy).

        Returns:
            The generated continuation: the first output sequence decoded
            with special tokens skipped, with the prompt text removed.
        """
        tokenized = self.tokenizer(prompt, return_tensors="pt", padding=True)
        # Send inputs to the model's device instead of a hard-coded "cuda" so
        # that any ``device_map`` passed to ``__init__`` works.
        device = self.model.device
        input_ids = tokenized["input_ids"].to(device)
        attention_mask = tokenized["attention_mask"].to(device)
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                pad_token_id=self.tokenizer.eos_token_id,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                num_beams=num_beams,
            )
        full_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Strip the prompt by the length of its *decoded* form: decoding does
        # not always reproduce the raw prompt character-for-character
        # (special tokens, spacing clean-up), so ``len(prompt)`` can cut the
        # output at the wrong offset.
        prompt_text = self.tokenizer.decode(
            tokenized["input_ids"][0], skip_special_tokens=True
        )
        return full_text[len(prompt_text):]


###############################################################################
# Extraction
###############################################################################


def extract_turtle(text: str, *, with_rationale: bool = False) -> str:
    """Extract the RDF Turtle part from Fujitsu-LLM-KG-8x7B_inst-infer output.

    A line is kept when, after stripping surrounding whitespace, it starts
    with ``<``, ``rel:``, ``rdf:`` or ``]`` (or ``#@`` when
    ``with_rationale`` is true). Kept lines are returned unmodified, so
    their indentation is preserved. Blank lines (empty or whitespace-only)
    are kept as empty lines once the first Turtle line has been seen;
    leading blank lines are dropped.

    Args:
        text: Raw model output.
        with_rationale: Also keep lines starting with ``#@``.

    Returns:
        The kept lines joined with newlines, or ``""`` if none were kept.
    """
    prefixes: Tuple[str, ...] = ("<", "rel:", "rdf:", "]")
    if with_rationale:
        prefixes += ("#@",)

    kept: List[str] = []
    for line in text.splitlines():
        stripped = line.strip()
        if not stripped:
            # Whitespace-only lines are blank separators just like "" lines;
            # blanks before the first Turtle line are skipped.
            if kept:
                kept.append("")
        elif stripped.startswith(prefixes):
            kept.append(line)
    return "\n".join(kept)


def extract_answer(text: str) -> Tuple[str, Sequence[str]]:
    """Extract the final answer and explore path from Fujitsu-LLM-KG-8x7B_inst-infer output.

    Content lines are non-blank lines that contain neither a code fence
    (three backticks) nor ``"## "``. They are collected into the path until
    a line containing ``## Answer`` is seen, then into the answer. A line
    containing ``## Explore Path`` (or ``## Answer``) restarts the path (or
    answer) collection. Parsing stops at the first code-fence line seen
    after at least one answer line.

    Args:
        text: Raw model output.

    Returns:
        ``(answer, path)``: the answer lines joined by newlines and stripped,
        and a tuple of the individually stripped path lines.
    """
    path: List[str] = []
    answer_lines: List[str] = []
    state: Literal["path", "answer"] = "path"
    for line in text.splitlines():
        if line.strip() and "```" not in line and "## " not in line:
            if state == "path":
                path.append(line)
            else:
                answer_lines.append(line)
        if "## Explore Path" in line:
            state = "path"
            path = []
        elif "## Answer" in line:
            state = "answer"
            answer_lines = []
        elif "```" in line and answer_lines:
            # A fence after a non-empty answer closes the answer block.
            break
    answer = "\n".join(answer_lines).strip()
    return answer, tuple(p.strip() for p in path)