Spaces:
Sleeping
Sleeping
Delete LLM/mlx_language_model.py
Browse files- LLM/mlx_language_model.py +0 -97
LLM/mlx_language_model.py
DELETED
@@ -1,97 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
from LLM.chat import Chat
|
3 |
-
from baseHandler import BaseHandler
|
4 |
-
from mlx_lm import load, stream_generate, generate
|
5 |
-
from rich.console import Console
|
6 |
-
import torch
|
7 |
-
|
8 |
-
# Configure root logging on import so every record carries a timestamp,
# the emitting logger's name and the level.
logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")

# Module-scoped logger used by the handler below.
logger = logging.getLogger(__name__)

# Shared rich console instance.
console = Console()
|
14 |
-
|
15 |
-
|
16 |
-
class MLXLanguageModelHandler(BaseHandler):
    """
    Handles the language model part.

    Loads a model/tokenizer pair with ``mlx_lm.load`` and, for every user
    prompt, streams the model's reply back in sentence-sized chunks while
    keeping the conversation in ``self.chat``.
    """

    # Generation budget used when ``gen_kwargs`` has no "max_new_tokens"
    # entry (previously this situation raised ``KeyError`` during warmup).
    _DEFAULT_MAX_NEW_TOKENS = 128

    def setup(
        self,
        model_name="microsoft/Phi-3-mini-4k-instruct",
        device="mps",
        torch_dtype="float16",
        gen_kwargs=None,
        user_role="user",
        chat_size=1,
        init_chat_role=None,
        init_chat_prompt="You are a helpful AI assistant.",
    ):
        """
        Load the model, prepare the chat history and run a warmup.

        Args:
            model_name: Identifier passed to ``mlx_lm.load``.
            device: Accepted for interface compatibility; not used here.
            torch_dtype: Accepted for interface compatibility; not used here.
            gen_kwargs: Generation options. Only ``"max_new_tokens"`` is
                read by this class. ``None`` means "no options".
            user_role: Role name attached to incoming user prompts.
            chat_size: Forwarded to ``Chat`` (presumably the number of
                exchanges to retain -- confirm against ``LLM.chat``).
            init_chat_role: If set, role of an initial message installed
                through ``Chat.init_chat``.
            init_chat_prompt: Content of that initial message.

        Raises:
            ValueError: If ``init_chat_role`` is set but
                ``init_chat_prompt`` is empty.
        """
        self.model_name = model_name
        self.model, self.tokenizer = load(self.model_name)
        # A fresh dict per call: a ``{}`` default in the signature would be
        # shared (and mutable) across every handler instance.
        self.gen_kwargs = {} if gen_kwargs is None else gen_kwargs

        self.chat = Chat(chat_size)
        if init_chat_role:
            if not init_chat_prompt:
                raise ValueError(
                    "An initial promt needs to be specified when setting init_chat_role."
                )
            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
        self.user_role = user_role

        self.warmup()

    def _max_new_tokens(self):
        """Return the configured token budget, or the class default."""
        return self.gen_kwargs.get("max_new_tokens", self._DEFAULT_MAX_NEW_TOKENS)

    def warmup(self):
        """Run two throwaway generations on a fixed prompt."""
        logger.info("Warming up %s", type(self).__name__)

        dummy_input_text = "Write me a poem about Machine Learning."
        dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]

        n_steps = 2

        # The templated prompt does not change between iterations, so build
        # it once. ``add_generation_prompt=True`` mirrors ``process`` so the
        # warmup exercises the same prompt shape as real inference.
        prompt = self.tokenizer.apply_chat_template(
            dummy_chat, tokenize=False, add_generation_prompt=True
        )
        max_tokens = self._max_new_tokens()
        for _ in range(n_steps):
            generate(
                self.model,
                self.tokenizer,
                prompt=prompt,
                max_tokens=max_tokens,
                verbose=False,
            )

    def process(self, prompt):
        """
        Generate a reply to ``prompt`` and yield it chunk by chunk.

        A chunk is emitted whenever the accumulated text ends with ".",
        "?", "!" or the "<|end|>" marker; the marker itself is stripped.
        Text left over when generation stops is yielded as a final chunk.
        Once the stream is exhausted the full reply is appended to the chat
        history under the "assistant" role.

        Args:
            prompt: The user's message text.

        Yields:
            str: Non-empty pieces of the reply, in order.
        """
        logger.debug("infering language model...")

        self.chat.append({"role": self.user_role, "content": prompt})

        chat_messages = self.chat.to_list()
        # Remove system messages if using a Gemma model
        if "gemma" in self.model_name.lower():
            chat_messages = [msg for msg in chat_messages if msg["role"] != "system"]

        templated_prompt = self.tokenizer.apply_chat_template(
            chat_messages, tokenize=False, add_generation_prompt=True
        )

        pieces = []
        curr_output = ""
        for t in stream_generate(
            self.model,
            self.tokenizer,
            templated_prompt,
            max_tokens=self._max_new_tokens(),
        ):
            # NOTE(review): older mlx_lm yields plain strings, newer releases
            # appear to yield response objects exposing ``.text`` -- accept
            # both; confirm against the pinned mlx_lm version.
            text = getattr(t, "text", t)
            pieces.append(text)
            curr_output += text
            if curr_output.endswith((".", "?", "!", "<|end|>")):
                chunk = curr_output.replace("<|end|>", "")
                curr_output = ""
                # A chunk made only of the end marker is empty once the
                # marker is stripped; there is nothing to hand downstream.
                if chunk:
                    yield chunk

        generated_text = "".join(pieces).replace("<|end|>", "")
        torch.mps.empty_cache()

        # Record the reply before the final yield so the history is complete
        # even if the consumer stops iterating right after the last chunk.
        self.chat.append({"role": "assistant", "content": generated_text})

        # Flush text that never reached a sentence terminator (e.g. the
        # reply was cut off by ``max_tokens``); it used to be dropped.
        tail = curr_output.replace("<|end|>", "")
        if tail:
            yield tail
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|