"""
This module contains functions for generating responses using LLMs.
"""

import enum
import json
import os
from random import sample
from uuid import uuid4

from firebase_admin import firestore
from google.cloud import secretmanager
from google.oauth2 import service_account
import gradio as gr
from litellm import completion

from credentials import get_credentials_json
from leaderboard import db

GOOGLE_CLOUD_PROJECT = os.environ.get("GOOGLE_CLOUD_PROJECT")
MODELS_SECRET = os.environ.get("MODELS_SECRET")

# The model configuration is loaded once, at import time, from the latest
# version of the `MODELS_SECRET` secret in Secret Manager. The payload is
# decoded as UTF-8 and parsed as JSON; `get_responses` indexes the result by
# model name to get a per-model config dict.
secretmanager_client = secretmanager.SecretManagerServiceClient(
    credentials=service_account.Credentials.from_service_account_info(
        get_credentials_json()))
models_secret = secretmanager_client.access_secret_version(
    name=secretmanager_client.secret_version_path(GOOGLE_CLOUD_PROJECT,
                                                  MODELS_SECRET, "latest"))
decoded_secret = models_secret.payload.data.decode("UTF-8")
supported_models = json.loads(decoded_secret)


def create_history(model_name: str, instruction: str, prompt: str,
                   response: str):
  """Records one model response in the "arena-history" Firestore collection.

  The document is stored under a fresh random ID (a uuid4 hex string), which
  is also written into the document's "id" field, and is stamped with the
  Firestore server time.

  Args:
    model_name: Name of the model that produced the response.
    instruction: System instruction that was sent to the model.
    prompt: User prompt that was sent to the model.
    response: Text returned by the model.
  """
  doc_id = uuid4().hex
  doc = {
      "id": doc_id,
      "model": model_name,
      "instruction": instruction,
      "prompt": prompt,
      "response": response,
      "timestamp": firestore.SERVER_TIMESTAMP
  }

  doc_ref = db.collection("arena-history").document(doc_id)
  doc_ref.set(doc)


class Category(enum.Enum):
  """Task categories a user can pick; the values are the UI-facing labels."""
  SUMMARIZE = "Summarize"
  TRANSLATE = "Translate"


# TODO(#31): Let the model builders set the instruction.
def get_instruction(category, source_lang, target_lang):
  """Returns the system instruction for the given task category.

  Args:
    category: A `Category` value, i.e. "Summarize" or "Translate".
    source_lang: Source language name; only used for translation.
    target_lang: Target language name; only used for translation.

  Returns:
    The instruction string to send to the model as the system message.

  Raises:
    ValueError: If `category` is not a supported `Category` value.
  """
  if category == Category.SUMMARIZE.value:
    return "Summarize the following text, maintaining the original language of the text in the summary."  # pylint: disable=line-too-long

  if category == Category.TRANSLATE.value:
    return f"Translate the following text from {source_lang} to {target_lang}."

  # Falling off the end would hand None to the model as its system prompt;
  # fail loudly with the offending value instead.
  raise ValueError(f"Unsupported category: {category!r}")
def get_responses(user_prompt, category, source_lang, target_lang):
  """Gets responses from two randomly chosen models and streams them.

  Two distinct models are sampled from `supported_models`. Each is sent the
  category instruction as the system message and `user_prompt` as the user
  message, and each response is recorded with `create_history`. The outputs
  are then revealed one character at a time to simulate concurrent streaming.

  Args:
    user_prompt: Text entered by the user.
    category: A `Category` value selecting the task.
    source_lang: Source language; required when category is "Translate".
    target_lang: Target language; required when category is "Translate".

  Yields:
    Lists of the form `[response_a, response_b, model_a, model_b,
    instruction]`. In the streaming updates each response is truncated to the
    first i + 1 characters; the final yield always carries the full responses.

  Raises:
    gr.Error: If no category is selected, translation languages are missing,
      or a model call fails.
  """
  if not category:
    raise gr.Error("Please select a category.")

  if (category == Category.TRANSLATE.value and
      (not source_lang or not target_lang)):
    raise gr.Error("Please select source and target languages.")

  models = sample(list(supported_models), 2)
  instruction = get_instruction(category, source_lang, target_lang)

  responses = []
  for model in models:
    model_config = supported_models[model]

    # Prefix the model with its provider, when one is configured, so litellm
    # can route the request.
    if "provider" in model_config:
      model_name = model_config["provider"] + "/" + model
    else:
      model_name = model

    api_key = model_config.get("apiKey")
    api_base = model_config.get("apiBase")

    try:
      # TODO(#1): Allow user to set configuration.
      completion_response = completion(model=model_name,
                                       api_key=api_key,
                                       api_base=api_base,
                                       messages=[{
                                           "content": instruction,
                                           "role": "system"
                                       }, {
                                           "content": user_prompt,
                                           "role": "user"
                                       }])
      # The OpenAI-format message schema allows `content` to be null.
      # Normalize it to "" so the streaming below can take its length.
      content = completion_response.choices[0].message.content or ""
      create_history(model, instruction, user_prompt, content)
      responses.append(content)

    # TODO(#1): Narrow down the exception type.
    except Exception as e:  # pylint: disable=broad-except
      print(f"Error with model {model}: {e}")
      # Chain the cause so the original failure stays in the traceback.
      raise gr.Error("Failed to get response. Please try again.") from e

  # It simulates concurrent stream response generation.
  max_response_length = max(len(response) for response in responses)
  for i in range(max_response_length):
    yield [response[:i + 1] for response in responses] + models + [instruction]

  yield responses + models + [instruction]