from typing import Any, Iterator
from threading import Thread

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from pydantic import Field, field_validator

from llama_index.core.llms import CustomLLM, LLMMetadata, CompletionResponse, CompletionResponseGen
from llama_index.core.llms.callbacks import llm_completion_callback
# for transformers 2
class GemmaLLMInterface(CustomLLM):
    # CustomLLM is a pydantic model, so configuration has to be declared as fields
    model_id: str = "google/gemma-2b-it"
    context_window: int = 8192
    num_output: int = 2048
    tokenizer: Any = None
    model: Any = None

    def __init__(
        self,
        model_id: str = "google/gemma-2b-it",
        context_window: int = 8192,
        num_output: int = 2048,
        **kwargs: Any,
    ):
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        )
        model.eval()
        # Pass everything through the pydantic constructor instead of plain
        # attribute assignment, which CustomLLM (a pydantic model) rejects
        super().__init__(
            model_id=model_id,
            context_window=context_window,
            num_output=num_output,
            tokenizer=tokenizer,
            model=model,
            **kwargs,
        )

    def _format_prompt(self, message: str) -> str:
        # Gemma chat template: one user turn followed by the model turn
        return f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_id,
        )

    def _prepare_inputs(self, prompt: str) -> dict:
        formatted_prompt = self._format_prompt(prompt)
        inputs = self.tokenizer(
            formatted_prompt, return_tensors="pt", add_special_tokens=True
        ).to(self.model.device)
        # Keep only the most recent tokens if the prompt exceeds the context window
        if inputs["input_ids"].shape[1] > self.context_window:
            inputs["input_ids"] = inputs["input_ids"][:, -self.context_window:]
            inputs["attention_mask"] = inputs["attention_mask"][:, -self.context_window:]
        return inputs

    def _generate(self, inputs: dict) -> Iterator[str]:
        # generate() returns a single output object rather than an iterator,
        # so decode only the newly generated tokens and yield them as one chunk
        output = self.model.generate(
            **inputs,
            max_new_tokens=self.num_output,
            do_sample=True,
            top_p=0.9,
            top_k=50,
            temperature=0.7,
            num_beams=1,
            repetition_penalty=1.1,
            return_dict_in_generate=True,
            output_scores=False,
        )
        new_tokens = output.sequences[:, inputs["input_ids"].shape[-1]:]
        yield self.tokenizer.decode(new_tokens[0], skip_special_tokens=True)

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        inputs = self._prepare_inputs(prompt)
        response = "".join(self._generate(inputs))
        return CompletionResponse(text=response)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        inputs = self._prepare_inputs(prompt)
        response = ""
        for new_token in self._generate(inputs):
            response += new_token
            yield CompletionResponse(text=response, delta=new_token)
# for transformers 1
"""class GemmaLLMInterface(CustomLLM):
    model: Any
    tokenizer: Any
    context_window: int = 8192
    num_output: int = 2048
    model_name: str = "gemma_2"

    def _format_prompt(self, message: str) -> str:
        return (
            f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
        )

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_name,
        )

    def _prepare_generation(self, prompt: str) -> tuple:
        prompt = self._format_prompt(prompt)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)
        inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(device)
        if inputs["input_ids"].shape[1] > self.context_window:
            inputs["input_ids"] = inputs["input_ids"][:, -self.context_window:]
        streamer = TextIteratorStreamer(self.tokenizer, timeout=None, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = {
            "input_ids": inputs["input_ids"],
            "streamer": streamer,
            "max_new_tokens": self.num_output,
            "do_sample": True,
            "top_p": 0.9,
            "top_k": 50,
            "temperature": 0.7,
            "num_beams": 1,
            "repetition_penalty": 1.1,
        }
        return streamer, generate_kwargs

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        streamer, generate_kwargs = self._prepare_generation(prompt)
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()
        response = ""
        for new_token in streamer:
            response += new_token
        return CompletionResponse(text=response)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        streamer, generate_kwargs = self._prepare_generation(prompt)
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()
        try:
            for new_token in streamer:
                yield CompletionResponse(text=new_token)
        except StopIteration:
            return"""