vilarin committed on
Commit
4b878db
1 Parent(s): 0ec7560

Update app/webui/patch.py

Files changed (1)
  1. app/webui/patch.py +168 -163
app/webui/patch.py CHANGED
@@ -1,164 +1,169 @@
-# a monkey patch to use llama-index completion
-import os
-import time
-from functools import wraps
-from threading import Lock
-from typing import Union
-import src.translation_agent.utils as utils
-
-from llama_index.llms.groq import Groq
-from llama_index.llms.cohere import Cohere
-from llama_index.llms.openai import OpenAI
-from llama_index.llms.together import TogetherLLM
-from llama_index.llms.ollama import Ollama
-from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
-
-from llama_index.core import Settings
-from llama_index.core.llms import ChatMessage
-
-RPM = 60
-
-# Add your LLMs here
-def model_load(
-    endpoint: str,
-    model: str,
-    api_key: str = None,
-    context_window: int = 4096,
-    num_output: int = 512,
-    rpm: int = RPM,
-):
-    if endpoint == "Groq":
-        llm = Groq(
-            model=model,
-            api_key=api_key if api_key else os.getenv("GROQ_API_KEY"),
-        )
-    elif endpoint == "Cohere":
-        llm = Cohere(
-            model=model,
-            api_key=api_key if api_key else os.getenv("COHERE_API_KEY"),
-        )
-    elif endpoint == "OpenAI":
-        llm = OpenAI(
-            model=model,
-            api_key=api_key if api_key else os.getenv("OPENAI_API_KEY"),
-        )
-    elif endpoint == "TogetherAI":
-        llm = TogetherLLM(
-            model=model,
-            api_key=api_key if api_key else os.getenv("TOGETHER_API_KEY"),
-        )
-    elif endpoint == "Ollama":
-        llm = Ollama(
-            model=model,
-            request_timeout=120.0)
-    elif endpoint == "Huggingface":
-        llm = HuggingFaceInferenceAPI(
-            model_name=model,
-            token=api_key if api_key else os.getenv("HF_TOKEN"),
-            task="text-generation",
-        )
-
-    global RPM
-    RPM = rpm
-
-    Settings.llm = llm
-    # maximum input size to the LLM
-    Settings.context_window = context_window
-
-    # number of tokens reserved for text generation.
-    Settings.num_output = num_output
-
-def rate_limit(get_max_per_minute):
-    def decorator(func):
-        lock = Lock()
-        last_called = [0.0]
-
-        @wraps(func)
-        def wrapper(*args, **kwargs):
-            with lock:
-                max_per_minute = get_max_per_minute()
-                min_interval = 60.0 / max_per_minute
-                elapsed = time.time() - last_called[0]
-                left_to_wait = min_interval - elapsed
-
-                if left_to_wait > 0:
-                    time.sleep(left_to_wait)
-
-                ret = func(*args, **kwargs)
-                last_called[0] = time.time()
-                return ret
-        return wrapper
-    return decorator
-
-@rate_limit(lambda: RPM)
-def get_completion(
-    prompt: str,
-    system_message: str = "You are a helpful assistant.",
-    temperature: float = 0.3,
-    json_mode: bool = False,
-) -> Union[str, dict]:
-    """
-    Generate a completion using the OpenAI API.
-
-    Args:
-        prompt (str): The user's prompt or query.
-        system_message (str, optional): The system message to set the context for the assistant.
-            Defaults to "You are a helpful assistant.".
-        temperature (float, optional): The sampling temperature for controlling the randomness of the generated text.
-            Defaults to 0.3.
-        json_mode (bool, optional): Whether to return the response in JSON format.
-            Defaults to False.
-
-    Returns:
-        Union[str, dict]: The generated completion.
-            If json_mode is True, returns the complete API response as a dictionary.
-            If json_mode is False, returns the generated text as a string.
-    """
-    print(time.localtime())
-    llm = Settings.llm
-    if llm.class_name() == "HuggingFaceInferenceAPI":
-        llm.system_prompt = system_message
-        messages = [
-            ChatMessage(
-                role="user", content=prompt),
-        ]
-
-        response = llm.chat(
-            messages=messages,
-            temperature=temperature,
-        )
-        return response.message.content
-    else:
-        messages = [
-            ChatMessage(
-                role="system", content=system_message),
-            ChatMessage(
-                role="user", content=prompt),
-        ]
-
-        if json_mode:
-            response = llm.chat(
-                temperature=temperature,
-                response_format={"type": "json_object"},
-                messages=messages,
-            )
-            return response.message.content
-        else:
-            response = llm.chat(
-                temperature=temperature,
-                messages=messages,
-            )
-            return response.message.content
-
-utils.get_completion = get_completion
-
-one_chunk_initial_translation = utils.one_chunk_initial_translation
-one_chunk_reflect_on_translation = utils.one_chunk_reflect_on_translation
-one_chunk_improve_translation = utils.one_chunk_improve_translation
-one_chunk_translate_text = utils.one_chunk_translate_text
-num_tokens_in_string = utils.num_tokens_in_string
-multichunk_initial_translation = utils.multichunk_initial_translation
-multichunk_reflect_on_translation = utils.multichunk_reflect_on_translation
-multichunk_improve_translation = utils.multichunk_improve_translation
-multichunk_translation = utils.multichunk_translation
+# a monkey patch to use llama-index completion
+import os
+import time
+import gradio as gr
+from functools import wraps
+from threading import Lock
+from typing import Union
+import src.translation_agent.utils as utils
+
+from llama_index.llms.groq import Groq
+from llama_index.llms.cohere import Cohere
+from llama_index.llms.openai import OpenAI
+from llama_index.llms.together import TogetherLLM
+from llama_index.llms.ollama import Ollama
+from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
+
+from llama_index.core import Settings
+from llama_index.core.llms import ChatMessage
+
+RPM = 60
+
+# Add your LLMs here
+def model_load(
+    endpoint: str,
+    model: str,
+    api_key: str = None,
+    context_window: int = 4096,
+    num_output: int = 512,
+    rpm: int = RPM,
+):
+    if endpoint == "Groq":
+        llm = Groq(
+            model=model,
+            api_key=api_key if api_key else os.getenv("GROQ_API_KEY"),
+        )
+    elif endpoint == "Cohere":
+        llm = Cohere(
+            model=model,
+            api_key=api_key if api_key else os.getenv("COHERE_API_KEY"),
+        )
+    elif endpoint == "OpenAI":
+        llm = OpenAI(
+            model=model,
+            api_key=api_key if api_key else os.getenv("OPENAI_API_KEY"),
+        )
+    elif endpoint == "TogetherAI":
+        llm = TogetherLLM(
+            model=model,
+            api_key=api_key if api_key else os.getenv("TOGETHER_API_KEY"),
+        )
+    elif endpoint == "Ollama":
+        llm = Ollama(
+            model=model,
+            request_timeout=120.0)
+    elif endpoint == "Huggingface":
+        llm = HuggingFaceInferenceAPI(
+            model_name=model,
+            token=api_key if api_key else os.getenv("HF_TOKEN"),
+            task="text-generation",
+        )
+
+    global RPM
+    RPM = rpm
+
+    Settings.llm = llm
+    # maximum input size to the LLM
+    Settings.context_window = context_window
+
+    # number of tokens reserved for text generation.
+    Settings.num_output = num_output
+
+def rate_limit(get_max_per_minute):
+    def decorator(func):
+        lock = Lock()
+        last_called = [0.0]
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            with lock:
+                max_per_minute = get_max_per_minute()
+                min_interval = 60.0 / max_per_minute
+                elapsed = time.time() - last_called[0]
+                left_to_wait = min_interval - elapsed
+
+                if left_to_wait > 0:
+                    time.sleep(left_to_wait)
+
+                ret = func(*args, **kwargs)
+                last_called[0] = time.time()
+                return ret
+        return wrapper
+    return decorator
+
+@rate_limit(lambda: RPM)
+def get_completion(
+    prompt: str,
+    system_message: str = "You are a helpful assistant.",
+    temperature: float = 0.3,
+    json_mode: bool = False,
+) -> Union[str, dict]:
+    """
+    Generate a completion using the OpenAI API.
+
+    Args:
+        prompt (str): The user's prompt or query.
+        system_message (str, optional): The system message to set the context for the assistant.
+            Defaults to "You are a helpful assistant.".
+        temperature (float, optional): The sampling temperature for controlling the randomness of the generated text.
+            Defaults to 0.3.
+        json_mode (bool, optional): Whether to return the response in JSON format.
+            Defaults to False.
+
+    Returns:
+        Union[str, dict]: The generated completion.
+            If json_mode is True, returns the complete API response as a dictionary.
+            If json_mode is False, returns the generated text as a string.
+    """
+    llm = Settings.llm
+    if llm.class_name() == "HuggingFaceInferenceAPI":
+        llm.system_prompt = system_message
+        messages = [
+            ChatMessage(
+                role="user", content=prompt),
+        ]
+        try:
+            response = llm.chat(
+                messages=messages,
+                temperature=temperature,
+            )
+            return response.message.content
+        except Exception as e:
+            raise gr.Error(f"An unexpected error occurred: {e}")
+    else:
+        messages = [
+            ChatMessage(
+                role="system", content=system_message),
+            ChatMessage(
+                role="user", content=prompt),
+        ]
+
+        if json_mode:
+            response = llm.chat(
+                temperature=temperature,
+                response_format={"type": "json_object"},
+                messages=messages,
+            )
+            return response.message.content
+        else:
+            try:
+                response = llm.chat(
+                    temperature=temperature,
+                    messages=messages,
+                )
+                return response.message.content
+            except Exception as e:
+                raise gr.Error(f"An unexpected error occurred: {e}")
+
+utils.get_completion = get_completion
+
+one_chunk_initial_translation = utils.one_chunk_initial_translation
+one_chunk_reflect_on_translation = utils.one_chunk_reflect_on_translation
+one_chunk_improve_translation = utils.one_chunk_improve_translation
+one_chunk_translate_text = utils.one_chunk_translate_text
+num_tokens_in_string = utils.num_tokens_in_string
+multichunk_initial_translation = utils.multichunk_initial_translation
+multichunk_reflect_on_translation = utils.multichunk_reflect_on_translation
+multichunk_improve_translation = utils.multichunk_improve_translation
+multichunk_translation = utils.multichunk_translation
 calculate_chunk_size =utils.calculate_chunk_size
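
For reference, a minimal sketch of how the patched module could be driven after this commit. The driver below is hypothetical: only model_load, get_completion, the endpoint names, and the environment-variable fallbacks come from the diff above; the model string is an assumption.

    # Hypothetical driver for app/webui/patch.py (not part of this commit).
    from app.webui.patch import model_load, get_completion

    # Bind llama-index's shared Settings.llm to an endpoint. The model name
    # is an assumption; with no api_key, the key falls back to GROQ_API_KEY.
    model_load(endpoint="Groq", model="llama3-70b-8192", rpm=60)

    # Calls are throttled to at most RPM per minute by the rate_limit
    # decorator; as of this commit, chat failures are re-raised as gr.Error
    # so the Gradio UI can surface them instead of crashing the worker.
    print(get_completion("Translate 'bonjour' into English."))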