# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from typing import List

import torch

from llama.tokenizer import Tokenizer
from llama.model import Transformer


class LLaMA:
    """Batch text generation on top of a ``Transformer`` and its ``Tokenizer``.

    An optional vision model can be supplied; when images are passed to
    :meth:`generate`, its output is forwarded to the transformer.
    """

    def __init__(self, model: Transformer, tokenizer: Tokenizer, vision_model=None):
        self.model = model
        self.tokenizer = tokenizer
        self.vision_model = vision_model

    def generate(
        self,
        prompts: List[str],
        imgs=None,
        max_gen_len: int = 512,
        temperature: float = 0.8,
        top_p: float = 0.95,
    ) -> List[str]:
        """Generate a continuation for every prompt in the batch.

        Args:
            prompts: Prompt strings; at most ``model.params.max_batch_size``.
            imgs: Optional input for ``self.vision_model``. When both are
                present, the vision model output is passed to the transformer
                and the mode switches from ``'instruct'`` to ``'caption'``.
            max_gen_len: Maximum number of tokens kept per continuation.
            temperature: Softmax temperature; a value ``<= 0`` selects greedy
                (argmax) decoding instead of sampling.
            top_p: Cumulative-probability threshold for top-p sampling.

        Returns:
            One decoded string per prompt, containing only the generated
            tokens (the prompt is stripped) and cut at the first EOS token.
            An empty ``prompts`` list yields an empty list.

        Raises:
            AssertionError: If the batch exceeds ``params.max_batch_size``.
        """
        bsz = len(prompts)
        if bsz == 0:
            # Nothing to generate; also avoids min()/max() on an empty list.
            return []

        params = self.model.params
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        mode = 'instruct'
        vision_tokens = None
        if imgs is not None and self.vision_model is not None:
            vision_tokens = self.vision_model(imgs)
            mode = 'caption'

        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
        prompt_lens = [len(t) for t in prompt_tokens]
        min_prompt_size = min(prompt_lens)
        max_prompt_size = max(prompt_lens)

        total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)

        # NOTE(review): a prompt longer than params.max_seq_len does not fit
        # this buffer and the copy below will fail on the size mismatch --
        # confirm whether callers are expected to truncate beforehand.
        tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t).long()
        # True wherever a prompt token is present; those slots are never overwritten.
        input_text_mask = tokens != self.tokenizer.pad_id

        eos_id = self.tokenizer.eos_id
        prompt_len_t = torch.tensor(prompt_lens, device=tokens.device)
        # Per-row flag: has this sequence already produced EOS after its prompt?
        eos_reached = torch.zeros(bsz, dtype=torch.bool, device=tokens.device)

        # Start at the shortest prompt so every row has at least one context token
        # behind it; longer prompts are "replayed" through the mask below.
        start_pos = min_prompt_size
        prev_pos = 0
        for cur_pos in range(start_pos, total_len):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos, vision_tokens, mode)
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            prev_pos = cur_pos

            # Stop as soon as every row has emitted EOS past its own prompt:
            # anything generated afterwards would be discarded by the EOS cut
            # below, so the remaining forward passes are wasted work.
            eos_reached |= (cur_pos >= prompt_len_t) & (next_token == eos_id)
            if bool(eos_reached.all()):
                break

        decoded = []
        for prompt_len, t in zip(prompt_lens, tokens.tolist()):
            # drop the prompt and cut to max gen len
            t = t[prompt_len : prompt_len + max_gen_len]
            # cut to eos tok if any
            try:
                t = t[: t.index(eos_id)]
            except ValueError:
                pass
            decoded.append(self.tokenizer.decode(t))
        return decoded


def sample_top_p(probs, p):
    """Sample one index per row from the top-p portion of ``probs``.

    Entries are sorted in descending order along the last dimension; an entry
    is kept only while the cumulative probability of the entries ranked before
    it does not exceed ``p``. The kept entries are renormalised and one is
    drawn with ``torch.multinomial``.

    Args:
        probs: Probabilities along the last dimension
            (presumably ``(batch, vocab)`` -- confirm against the caller).
        p: Cumulative-probability threshold.

    Returns:
        Tensor of sampled indices into the last dimension of ``probs``, with a
        trailing dimension of size 1.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Subtracting the entry itself gives the mass *before* it, so the top-ranked
    # entry is always kept.
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map positions in the sorted order back to original indices.
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token