File size: 3,749 Bytes
3c495cc
 
 
 
 
b695eaf
 
3c495cc
17b1403
3c495cc
17b1403
b695eaf
50a5912
3c495cc
 
 
50a5912
17b1403
50a5912
b695eaf
 
 
50a5912
 
 
 
 
 
b695eaf
 
 
3c495cc
 
17b1403
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c495cc
 
 
 
 
e4bf3a0
3c495cc
 
a44d1c3
3c495cc
 
 
 
 
 
 
 
 
 
 
 
b695eaf
3c495cc
 
25a095e
3c495cc
b695eaf
 
 
 
 
 
 
3c495cc
 
b695eaf
 
 
3c495cc
 
 
 
 
 
25a095e
17b1403
 
 
3c495cc
 
 
 
 
 
25a095e
 
 
3a85228
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""
This module contains functions for generating responses using LLMs.
"""

import enum
import json
import os
from random import sample
from uuid import uuid4

from firebase_admin import firestore
from google.cloud import secretmanager
from google.oauth2 import service_account
import gradio as gr
from litellm import completion

from credentials import get_credentials_json
from leaderboard import db

# Both values are read from the environment at import time; os.environ.get
# yields None when a variable is unset.
GOOGLE_CLOUD_PROJECT = os.environ.get("GOOGLE_CLOUD_PROJECT")
MODELS_SECRET = os.environ.get("MODELS_SECRET")

# Secret Manager client authenticated with the service account info returned
# by get_credentials_json().
secretmanager_client = secretmanager.SecretManagerServiceClient(
    credentials=service_account.Credentials.from_service_account_info(
        get_credentials_json()))
# Fetch the "latest" version of the secret named by MODELS_SECRET in the
# GOOGLE_CLOUD_PROJECT project. This is a network call made on import.
models_secret = secretmanager_client.access_secret_version(
    name=secretmanager_client.secret_version_path(GOOGLE_CLOUD_PROJECT,
                                                  MODELS_SECRET, "latest"))
decoded_secret = models_secret.payload.data.decode("UTF-8")

# The secret payload is JSON. get_responses() below uses it as a mapping from
# model name to a config dict with optional "provider", "apiKey" and "apiBase"
# keys.
supported_models = json.loads(decoded_secret)


def create_history(model_name: str, instruction: str, prompt: str,
                   response: str):
  """Stores a single model exchange in the "arena-history" collection.

  A fresh random hex id is generated for every call and used both as the
  Firestore document id and as the document's "id" field.

  Args:
    model_name: Name of the model that produced the response.
    instruction: Instruction that was sent along with the prompt.
    prompt: The user's prompt.
    response: The text returned by the model.
  """
  history_id = uuid4().hex

  # The timestamp is a sentinel that Firestore fills in on the server side.
  db.collection("arena-history").document(history_id).set({
      "id": history_id,
      "model": model_name,
      "instruction": instruction,
      "prompt": prompt,
      "response": response,
      "timestamp": firestore.SERVER_TIMESTAMP
  })


class Category(enum.Enum):
  """Task categories a prompt can be submitted under.

  Callers in this module compare the selected category against the members'
  string values (e.g. Category.SUMMARIZE.value), not the members themselves.
  """
  # Summarize the prompt, keeping its original language.
  SUMMARIZE = "Summarize"
  # Translate the prompt from a source language to a target language.
  TRANSLATE = "Translate"


# TODO(#31): Let the model builders set the instruction.
def get_instruction(category, source_lang, target_lang):
  """Builds the instruction text for the given category.

  Args:
    category: Category value string, compared against Category.*.value.
    source_lang: Language to translate from; only used for translation.
    target_lang: Language to translate into; only used for translation.

  Returns:
    The instruction string for the category, or None when the category is
    neither the summarize nor the translate value.
  """
  instruction = None

  if category == Category.TRANSLATE.value:
    instruction = (
        f"Translate the following text from {source_lang} to {target_lang}.")
  elif category == Category.SUMMARIZE.value:
    instruction = "Summarize the following text, maintaining the original language of the text in the summary."  # pylint: disable=line-too-long

  return instruction


def get_responses(user_prompt, category, source_lang, target_lang):
  """Queries two randomly chosen models and streams their responses.

  This is a generator: nothing, including input validation, runs until the
  first item is requested. Both completions are fetched in full first; the
  text is then replayed one character at a time.

  Args:
    user_prompt: Text sent to each model as the user message.
    category: Selected category value (compared against Category.*.value).
    source_lang: Source language; required when translating.
    target_lang: Target language; required when translating.

  Yields:
    A list [response_a, response_b, model_a, model_b, instruction], where the
    two responses are growing prefixes of the full texts. The last item
    yielded carries the complete responses.

  Raises:
    gr.Error: If no category is selected, or if the translate category is
      selected without both a source and a target language.
    ValueError: From random.sample when fewer than two models are configured.
    Exception: Any error raised while calling a model or saving the history
      is printed and re-raised unchanged.
  """
  if not category:
    raise gr.Error("Please select a category.")

  if category == Category.TRANSLATE.value and (not source_lang or
                                               not target_lang):
    raise gr.Error("Please select source and target languages.")

  models = sample(list(supported_models), 2)
  instruction = get_instruction(category, source_lang, target_lang)

  responses = []
  for model in models:
    model_config = supported_models[model]

    # The completion call expects "<provider>/<model>" when a provider is
    # configured and the bare model name otherwise.
    if "provider" in model_config:
      model_name = model_config["provider"] + "/" + model
    else:
      model_name = model
    api_key = model_config.get("apiKey", None)
    api_base = model_config.get("apiBase", None)

    try:
      # TODO(#1): Allow user to set configuration.
      response = completion(model=model_name,
                            api_key=api_key,
                            api_base=api_base,
                            messages=[{
                                "content": instruction,
                                "role": "system"
                            }, {
                                "content": user_prompt,
                                "role": "user"
                            }])
      # A completion may carry no text (content is None). Normalize it to an
      # empty string so the len()/slicing in the streaming loop below cannot
      # fail with a TypeError.
      content = response.choices[0].message.content or ""
      create_history(model, instruction, user_prompt, content)
      responses.append(content)

    # TODO(#1): Narrow down the exception type.
    except Exception as e:  # pylint: disable=broad-except
      print(f"Error in bot_response: {e}")
      # Bare raise re-raises the active exception as-is.
      raise

  # It simulates concurrent stream response generation.
  max_response_length = max(len(text) for text in responses)
  for i in range(max_response_length):
    yield [text[:i + 1] for text in responses] + models + [instruction]

  # Final state; also the only item yielded when both responses are empty.
  yield responses + models + [instruction]