# paper_qa / gen / gemini_chat.py
# Last commit 644f805 by chansung: "Update gen/gemini_chat.py" (verified); file size 3.8 kB.
import asyncio
import copy
import logging

import google.generativeai as genai
from pingpong import PingPong
from pingpong.pingpong import PPManager
from pingpong.pingpong import PromptFmt
from pingpong.pingpong import UIFmt
from pingpong.gradio import GradioChatUIFmt
class GeminiChatPromptFmt(PromptFmt):
    """Format context and ping/pong exchanges as Gemini chat message dicts.

    Every message produced here is a dict with a ``role`` key and a ``parts``
    list holding a single value.
    """

    @classmethod
    def ctx(cls, context):
        """Wrap *context* in a ``system`` message.

        Returns None when *context* is None or the empty string.
        """
        if context is not None and context != "":
            return {"role": "system", "parts": [context]}
        return None

    @classmethod
    def prompt(cls, pingpong, truncate_size):
        """Return the user turn and, if a non-empty reply exists, the model turn.

        Both ``pingpong.ping`` and ``pingpong.pong`` are sliced with
        ``[:truncate_size]`` (a ``truncate_size`` of None keeps them whole).
        A missing (None) or empty reply produces no ``model`` message.
        """
        user_text = pingpong.ping[:truncate_size]
        messages = [{"role": "user", "parts": [user_text]}]

        reply = pingpong.pong
        model_text = "" if reply is None else reply[:truncate_size]
        if model_text != "":
            messages.append({"role": "model", "parts": [model_text]})
        return messages
class GeminiChatPPManager(PPManager):
    """PPManager that renders its ping/pong history as Gemini chat messages."""

    def build_prompts(self, from_idx: int=0, to_idx: int=-1, fmt: PromptFmt=GeminiChatPromptFmt, truncate_size: int=None):
        """Render ``self.pingpongs[from_idx:to_idx]`` into a flat list of messages.

        The context (``self.ctx``, passed through ``fmt.ctx``) is folded into the
        first rendered user turn as a ``"SYSTEM: ... -----------"`` header, and
        only when that context is non-empty. ``self.pingpongs`` is never
        mutated, because the header is added to a deep copy.

        Args:
            from_idx: index of the first pingpong to render.
            to_idx: exclusive end index; -1 or any value past the end means
                "through the last pingpong".
            fmt: formatter providing ``ctx`` and ``prompt`` class methods.
            truncate_size: forwarded to ``fmt.prompt``. Note that truncation is
                applied after the context header is prepended, so a long
                context consumes the truncation budget of the first ping.

        Returns:
            The concatenation of ``fmt.prompt(...)`` results, in order.
        """
        if to_idx == -1 or to_idx >= len(self.pingpongs):
            to_idx = len(self.pingpongs)

        # Deep-copy only the rendered window. The first entry is modified below,
        # and the stored history must stay untouched.
        window = copy.deepcopy(self.pingpongs[from_idx:to_idx])

        ctx = fmt.ctx(self.ctx)
        ctx = ctx['parts'][0] if ctx is not None else ""

        # Skip the header entirely when there is no context. Otherwise an empty
        # "SYSTEM:  -----------" line would be injected into the user's message.
        if window and ctx:
            window[0].ping = f"SYSTEM: {ctx} ----------- \n" + window[0].ping

        results = []
        for pingpong in window:
            results += fmt.prompt(pingpong, truncate_size=truncate_size)
        return results
class GradioGeminiChatPPManager(GeminiChatPPManager):
    """GeminiChatPPManager that can also render its history for a Gradio chat UI."""

    def build_uis(self, from_idx: int=0, to_idx: int=-1, fmt: UIFmt=GradioChatUIFmt):
        """Return ``fmt.ui(pingpong)`` for each pingpong in ``[from_idx, to_idx)``.

        A ``to_idx`` of -1, or any value past the end, means "through the last
        pingpong".
        """
        total = len(self.pingpongs)
        stop = total if to_idx == -1 or to_idx >= total else to_idx
        return [fmt.ui(pp) for pp in self.pingpongs[from_idx:stop]]
def init(api_key):
    """Configure the ``google.generativeai`` client with *api_key*."""
    options = {"api_key": api_key}
    genai.configure(**options)
def _default_gen_text():
return {
"temperature": 0.9,
"top_p": 0.8,
"top_k": 32,
"max_output_tokens": 2048,
}
def _default_safety_settings():
return [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE"
},
]
async def _word_generator(sentence):
for word in sentence.split(" "):
yield word
delay = 0.03 + (len(word) * 0.005)
await asyncio.sleep(delay) # Simulate a short delay
async def gen_text(
    prompts,
    gen_config=None,
    safety_settings=None,
    stream=True,
    model_name="gemini-1.0-pro",
):
    """Send a chat conversation to Gemini and yield the reply piece by piece.

    Args:
        prompts: non-empty list of chat messages (dicts with ``role`` and
            ``parts``). The last entry's ``parts[0]`` is sent as the new
            message; all earlier entries are passed to ``start_chat`` as history.
        gen_config: generation config for ``genai.GenerativeModel``. When
            omitted (None), a new ``_default_gen_text()`` dict is built for
            each call.
        safety_settings: safety settings for ``genai.GenerativeModel``. When
            omitted (None), a new ``_default_safety_settings()`` list is built
            for each call.
        stream: forwarded to ``send_message_async``.
        model_name: name of the Gemini model to use.

    Yields:
        str: pieces of each response chunk's ``.text``. Joining all yielded
        pieces reproduces the concatenation of the chunks' texts exactly.

    Raises:
        IndexError: if *prompts* is empty.
    """
    # Build the defaults per call rather than once at import time, so one
    # shared mutable dict/list can never carry changes between calls.
    if gen_config is None:
        gen_config = _default_gen_text()
    if safety_settings is None:
        safety_settings = _default_safety_settings()

    model = genai.GenerativeModel(
        model_name=model_name,
        generation_config=gen_config,
        safety_settings=safety_settings,
    )

    # Everything except the final message is history; the final message is the new turn.
    history, user_prompt = prompts[:-1], prompts[-1]
    convo = model.start_chat(history=history)
    resps = await convo.send_message_async(user_prompt["parts"][0], stream=stream)

    logger = logging.getLogger(__name__)
    async for resp in resps:
        # Debug-level log instead of print. The guard prevents the logging itself
        # from raising IndexError when a chunk has no candidates.
        if resp.candidates:
            logger.debug("finish_reason=%s", resp.candidates[0].finish_reason.name)
        # Re-insert only the spaces that actually separated words in this chunk.
        # Chunk boundaries need not fall on word boundaries, so appending a
        # space after every chunk's last word would add spurious or doubled spaces.
        first = True
        async for word in _word_generator(resp.text):
            yield word if first else " " + word
            first = False