Kang Suhyun committed on
Commit
9e789e7
1 Parent(s): 6d880cd

[#104] Display error message for the context window exceeded error (#105)

Browse files

Changes:
- Added exception handling for `ContextWindowExceededError` in the relevant module.
- Logged the error to capture detailed error information.
- Displayed an error message to inform users about the error and how to resolve it.

Files changed (2) hide show
  1. model.py +16 -8
  2. response.py +14 -4
model.py CHANGED
@@ -29,6 +29,10 @@ DEFAULT_SUMMARIZE_INSTRUCTION = "Summarize the following text, maintaining the l
29
  DEFAULT_TRANSLATE_INSTRUCTION = "Translate the following text from {source_lang} to {target_lang}." # pylint: disable=line-too-long
30
 
31
 
 
 
 
 
32
  class Model:
33
 
34
  def __init__(
@@ -49,14 +53,18 @@ class Model:
49
  self.translate_instruction = translateInstruction or DEFAULT_TRANSLATE_INSTRUCTION # pylint: disable=line-too-long
50
 
51
  def completion(self, messages: List, max_tokens: float = None) -> str:
52
- response = litellm.completion(model=self.provider + "/" +
53
- self.name if self.provider else self.name,
54
- api_key=self.api_key,
55
- api_base=self.api_base,
56
- messages=messages,
57
- max_tokens=max_tokens)
58
-
59
- return response.choices[0].message.content
 
 
 
 
60
 
61
 
62
  supported_models: List[Model] = [
 
29
  DEFAULT_TRANSLATE_INSTRUCTION = "Translate the following text from {source_lang} to {target_lang}." # pylint: disable=line-too-long
30
 
31
 
32
+ class ContextWindowExceededError(Exception):
33
+ pass
34
+
35
+
36
  class Model:
37
 
38
  def __init__(
 
53
  self.translate_instruction = translateInstruction or DEFAULT_TRANSLATE_INSTRUCTION # pylint: disable=line-too-long
54
 
55
  def completion(self, messages: List, max_tokens: float = None) -> str:
56
+ try:
57
+ response = litellm.completion(model=self.provider + "/" +
58
+ self.name if self.provider else self.name,
59
+ api_key=self.api_key,
60
+ api_base=self.api_base,
61
+ messages=messages,
62
+ max_tokens=max_tokens)
63
+
64
+ return response.choices[0].message.content
65
+
66
+ except litellm.ContextWindowExceededError as e:
67
+ raise ContextWindowExceededError() from e
68
 
69
 
70
  supported_models: List[Model] = [
response.py CHANGED
@@ -3,6 +3,7 @@ This module contains functions for generating responses using LLMs.
3
  """
4
 
5
  import enum
 
6
  from random import sample
7
  from typing import List
8
  from uuid import uuid4
@@ -11,9 +12,14 @@ from firebase_admin import firestore
11
  import gradio as gr
12
 
13
  from leaderboard import db
 
14
  from model import Model
15
  from model import supported_models
16
 
 
 
 
 
17
 
18
  def get_history_collection(category: str):
19
  if category == Category.SUMMARIZE.value:
@@ -81,10 +87,14 @@ def get_responses(prompt: str, category: str, source_lang: str,
81
  create_history(category, model.name, instruction, prompt, response)
82
  responses.append(response)
83
 
84
- # TODO(#1): Narrow down the exception type.
85
- except Exception as e: # pylint: disable=broad-except
86
- print(f"Error with model {model.name}: {e}")
87
- raise gr.Error("Failed to get response. Please try again.")
 
 
 
 
88
 
89
  model_names = [model.name for model in models]
90
 
 
3
  """
4
 
5
  import enum
6
+ import logging
7
  from random import sample
8
  from typing import List
9
  from uuid import uuid4
 
12
  import gradio as gr
13
 
14
  from leaderboard import db
15
+ from model import ContextWindowExceededError
16
  from model import Model
17
  from model import supported_models
18
 
19
+ logging.basicConfig()
20
+ logger = logging.getLogger(__name__)
21
+ logger.setLevel(logging.INFO)
22
+
23
 
24
  def get_history_collection(category: str):
25
  if category == Category.SUMMARIZE.value:
 
87
  create_history(category, model.name, instruction, prompt, response)
88
  responses.append(response)
89
 
90
+ except ContextWindowExceededError as e:
91
+ logger.exception("Context window exceeded for model %s.", model.name)
92
+ raise gr.Error(
93
+ "The prompt is too long. Please try again with a shorter prompt."
94
+ ) from e
95
+ except Exception as e:
96
+ logger.exception("Failed to get response from model %s.", model.name)
97
+ raise gr.Error("Failed to get response. Please try again.") from e
98
 
99
  model_names = [model.name for model in models]
100