cwadayi commited on
Commit
98ed026
·
verified ·
1 Parent(s): d1aea6d

Create src/providers/gemini.py

Browse files
Files changed (1) hide show
  1. src/providers/gemini.py +54 -0
src/providers/gemini.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/providers/gemini.py
2
+ import os
3
+ from typing import List, Dict, Optional
4
+ import google.generativeai as genai
5
+
6
+ class GeminiProvider:
7
+ _inited = False
8
+
9
+ @classmethod
10
+ def _ensure_init(cls):
11
+ if cls._inited:
12
+ return
13
+ api_key = os.getenv("GEMINI_API_KEY")
14
+ if not api_key:
15
+ raise RuntimeError("Missing GEMINI_API_KEY (set it in Spaces Secrets).")
16
+ genai.configure(api_key=api_key)
17
+ cls._inited = True
18
+
19
+ @staticmethod
20
+ def _messages_to_history(messages: List[Dict[str, str]], system_prompt: Optional[str]):
21
+ history = []
22
+ if system_prompt:
23
+ history.append({"role": "user", "parts": system_prompt})
24
+ for m in messages:
25
+ role = m.get("role", "user")
26
+ content = m.get("content", "")
27
+ if role == "assistant":
28
+ history.append({"role": "model", "parts": content})
29
+ else:
30
+ history.append({"role": "user", "parts": content})
31
+ return history
32
+
33
+ @classmethod
34
+ def generate(
35
+ cls,
36
+ model_name: str,
37
+ messages: List[Dict[str, str]],
38
+ system_prompt: Optional[str] = None,
39
+ max_tokens: int = 8192,
40
+ temperature: float = 0.7,
41
+ ) -> str:
42
+ cls._ensure_init()
43
+ model = genai.GenerativeModel(model_name)
44
+ history = cls._messages_to_history(messages, system_prompt)
45
+ chat = model.start_chat(history=history[:-1] if len(history) > 1 else [])
46
+ user_turn = history[-1]["parts"] if history else ""
47
+ resp = chat.send_message(
48
+ user_turn,
49
+ generation_config=genai.types.GenerationConfig(
50
+ max_output_tokens=max_tokens,
51
+ temperature=temperature,
52
+ ),
53
+ )
54
+ return resp.text or ""