import time

import google.generativeai as genai
import tiktoken
from transformers import AutoTokenizer

from .Model import Model

class Gemini(Model):
    def __init__(self, config):
        super().__init__(config)
        api_keys = config["api_key_info"]["api_keys"]
        api_pos = int(config["api_key_info"]["api_key_use"])
        assert 0 <= api_pos < len(api_keys), "Please provide a valid API key index"
        self.max_output_tokens = int(config["params"]["max_output_tokens"])
        genai.configure(api_key=api_keys[api_pos])
        # Instantiate the Gemini model by name (self.name is set by the Model base class).
        self.model = genai.GenerativeModel(self.name)
        # Local tokenizers, used only to count and truncate prompt tokens;
        # tiktoken's gpt-3.5-turbo encoding serves as an approximation for Gemini.
        self.llama_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
        self.encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
        self.seed = 10

    def query(self, msg, max_tokens=128000):
        super().query(max_tokens)
        while True:
            try:
                generation_config = genai.types.GenerationConfig(
                    temperature=self.temperature,
                    max_output_tokens=self.max_output_tokens,
                    candidate_count=1,
                )
                response = self.model.generate_content(
                    contents=msg,
                    generation_config=generation_config,
                )
                # Retry if the response was blocked by a safety filter.
                if response.candidates and response.candidates[0].finish_reason == 3:  # FinishReason.SAFETY
                    ratings = response.prompt_feedback.safety_ratings
                    blocked_filter = ratings[0].category if ratings else "an unknown"
                    print(f"Warning: response was blocked by {blocked_filter} safety filter. Retrying...")
                    continue
                if not response.text:
                    raise ValueError("Empty response from Gemini API")
                time.sleep(1)  # brief pause to stay under rate limits
                break
            except Exception as e:
                print(f"Error in Gemini API call: {str(e)}")
                time.sleep(100)  # long back-off before retrying (e.g., after a rate-limit error)
        return response.text

    def get_prompt_length(self, msg):
        # Count tokens with the shared tiktoken encoding (approximates Gemini's tokenizer).
        return len(self.encoding.encode(msg))

    def cut_context(self, msg, max_length):
        # Truncate msg to at most max_length tokens, then decode back to text.
        tokens = self.encoding.encode(msg)
        truncated_tokens = tokens[:max_length]
        truncated_text = self.encoding.decode(truncated_tokens)
        return truncated_text
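

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of constructing and querying the wrapper. The config layout
# below is inferred from __init__; the "model_info" key, the model name, and the
# "temperature" parameter are assumptions about what the Model base class reads,
# since that class is not shown here.
if __name__ == "__main__":
    config = {
        "api_key_info": {
            "api_keys": ["YOUR_GEMINI_API_KEY"],  # placeholder, not a real key
            "api_key_use": 0,                     # index into api_keys
        },
        "params": {
            "max_output_tokens": 256,
            "temperature": 0.1,  # assumed: read by the Model base class
        },
        "model_info": {"name": "gemini-1.5-flash"},  # assumed: base class sets self.name
    }
    model = Gemini(config)
    prompt = model.cut_context("Summarize retrieval-augmented generation.", max_length=1024)
    print(model.query(prompt))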