Spaces:
Sleeping
Sleeping
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from llama_index.core.llms import CustomLLM, LLMMetadata, CompletionResponse, CompletionResponseGen | |
| from llama_index.core.llms.callbacks import llm_completion_callback | |
| from typing import Any, Iterator | |
| import torch | |
| from transformers import TextIteratorStreamer | |
| from threading import Thread | |
| from pydantic import Field, field_validator | |
# object.__setattr__ is used below to bypass Pydantic's check on undeclared
# attributes of the CustomLLM model (original note: "for transformers 2").
class GemmaLLMInterface(CustomLLM):
    """llama_index ``CustomLLM`` backed by a local Hugging Face Gemma chat model.

    The tokenizer and model are loaded once in ``__init__`` and reused by
    every call. ``complete`` returns the whole answer at once;
    ``stream_complete`` yields it incrementally while generation runs on a
    background thread.
    """

    def __init__(self, model_id: str = "google/gemma-2-2b-it", **kwargs: Any):
        """Load the tokenizer and model for ``model_id``.

        Args:
            model_id: Hugging Face model id (or path) passed to ``from_pretrained``.
            **kwargs: Forwarded unchanged to ``CustomLLM.__init__``.
        """
        super().__init__(**kwargs)
        object.__setattr__(self, "model_id", model_id)
        # bfloat16 weights when a CUDA device is available, float32 otherwise.
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        object.__setattr__(self, "model", model)
        object.__setattr__(self, "tokenizer", tokenizer)
        object.__setattr__(self, "context_window", 8192)
        # Also used as max_new_tokens for every generate() call.
        object.__setattr__(self, "num_output", 2048)

    def _format_prompt(self, message: str) -> str:
        """Wrap ``message`` as a single Gemma user turn followed by an open model turn."""
        return (
            f"<start_of_turn>user\n{message}<end_of_turn>\n"
            f"<start_of_turn>model\n"
        )

    @property
    def metadata(self) -> LLMMetadata:
        """Describe this LLM to llama_index (context window, output budget, model name).

        Must be a property: llama_index reads ``llm.metadata.<field>`` directly.
        """
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_id,
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        """Generate a complete answer for ``prompt``.

        The prompt is wrapped with ``_format_prompt`` before generation. Extra
        ``kwargs`` are accepted for interface compatibility and ignored.

        Returns:
            A ``CompletionResponse`` holding the generated text, or
            ``"No response generated."`` when the model produced nothing.
        """
        formatted_prompt = self._format_prompt(prompt)
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(**inputs, max_new_tokens=self.num_output)
        # Decode only the newly generated tokens. Cutting the decoded string at
        # len(prompt) is unreliable: skip_special_tokens=True removes special
        # tokens such as the turn markers, so the decoded prefix is shorter than
        # the prompt string and the cut would eat into the answer.
        prompt_length = inputs["input_ids"].shape[-1]
        response = self.tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()
        return CompletionResponse(text=response if response else "No response generated.")

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        """Stream the answer for ``prompt`` while it is being generated.

        ``model.generate`` runs on a background thread and feeds a
        ``TextIteratorStreamer``. Each yielded ``CompletionResponse`` carries the
        cumulative text so far in ``text`` and the new chunk in ``delta``. If
        nothing was generated, a single ``"No response generated."`` response
        is yielded. Extra ``kwargs`` are accepted and ignored.

        NOTE(review): unlike ``complete``, the prompt is sent to the model as-is
        (the ``_format_prompt`` call was disabled in the original); confirm
        whether callers already supply Gemma turn markup.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        # skip_prompt=True: otherwise the streamer first emits the decoded
        # prompt itself, and callers would receive their own input as output.
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(inputs, max_new_tokens=self.num_output, streamer=streamer)
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
        streamed_response = ""
        for new_text in streamer:
            if not new_text:
                continue
            streamed_response += new_text
            yield CompletionResponse(text=streamed_response, delta=new_text)
        # The streamer is exhausted only once generate() has finished, so this
        # join returns promptly and reaps the worker thread.
        thread.join()
        if not streamed_response:
            yield CompletionResponse(text="No response generated.", delta="No response generated.")