Spaces:
Running
Running
from typing import Any, Optional

import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold
from llama_index.core.llms import (
    CustomLLM,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.core.llms.callbacks import llm_completion_callback
class GLLM(CustomLLM):
    """LlamaIndex ``CustomLLM`` adapter around a Google Gemini ``GenerativeModel``.

    Each completion is produced by one ``generate_content`` call. Streaming
    is simulated: the full text is generated first and then replayed one
    character at a time.
    """

    def __init__(
        self,
        context_window: int = 32768,
        num_output: int = 4098,
        model_name: str = "gemini-1.5-flash",
        system_instruction: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Create the adapter and the underlying Gemini model.

        Args:
            context_window: Context size reported through ``metadata``.
            num_output: Output-token budget reported through ``metadata``.
                It is only reported; it is NOT sent to Gemini as a
                generation limit. NOTE(review): 4098 may be a typo for
                4096 — confirm.
            model_name: Gemini model identifier passed to ``GenerativeModel``.
            system_instruction: Optional system prompt given to the model;
                ``None`` means no system instruction.
            **kwargs: Forwarded unchanged to ``CustomLLM.__init__``.
        """
        super().__init__(**kwargs)
        self._context_window = context_window
        self._num_output = num_output
        self._model_name = model_name
        self._model = genai.GenerativeModel(
            model_name, system_instruction=system_instruction
        )

    def gai_generate_content(self, prompt: str, temperature: float = 0.5) -> str:
        """Run one Gemini generation for ``prompt`` and return its text.

        Blocking for hate speech and harassment is relaxed to ``BLOCK_NONE``.
        Other harm categories are not overridden here.

        Args:
            prompt: Text sent to the model.
            temperature: Sampling temperature for this request.

        Returns:
            The ``.text`` of the model response.
        """
        response = self._model.generate_content(
            prompt,
            generation_config=genai.GenerationConfig(temperature=temperature),
            safety_settings={
                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            },
        )
        # NOTE(review): `.text` may raise if the response carries no text
        # (e.g. a blocked prompt) — confirm against the SDK and decide
        # whether callers need a friendlier error.
        return response.text

    # Must be a property: LlamaIndex reads attributes off `llm.metadata`
    # (e.g. `.context_window`), which fails on a plain bound method.
    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self._context_window,
            num_output=self._num_output,
            model_name=self._model_name,
        )

    # LlamaIndex's completion-callback decorator, as used in its custom-LLM
    # pattern for both `complete` and `stream_complete`.
    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        """Return the full completion for ``prompt``.

        Extra ``kwargs`` are accepted for interface compatibility and
        ignored; generation always uses the default temperature.
        """
        return CompletionResponse(text=self.gai_generate_content(prompt))

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, **kwargs: Any
    ) -> CompletionResponseGen:
        """Yield the completion for ``prompt`` one character at a time.

        The whole response is generated before the first chunk is yielded.
        Each chunk's ``text`` holds the accumulated output so far, and its
        ``delta`` holds the newly added character. Extra ``kwargs`` are
        ignored.
        """
        text = self.gai_generate_content(prompt)
        response = ""
        for token in text:
            response += token
            yield CompletionResponse(text=response, delta=token)