Siddhant committed
Commit afc5b80
1 Parent(s): 3c5dadd

Delete LLM/mlx_language_model.py

Files changed (1)
  1. LLM/mlx_language_model.py +0 -97
LLM/mlx_language_model.py DELETED
@@ -1,97 +0,0 @@
-import logging
-from LLM.chat import Chat
-from baseHandler import BaseHandler
-from mlx_lm import load, stream_generate, generate
-from rich.console import Console
-import torch
-
-logging.basicConfig(
-    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
-)
-logger = logging.getLogger(__name__)
-
-console = Console()
-
-
-class MLXLanguageModelHandler(BaseHandler):
-    """
-    Handles the language model part.
-    """
-
-    def setup(
-        self,
-        model_name="microsoft/Phi-3-mini-4k-instruct",
-        device="mps",
-        torch_dtype="float16",
-        gen_kwargs={},
-        user_role="user",
-        chat_size=1,
-        init_chat_role=None,
-        init_chat_prompt="You are a helpful AI assistant.",
-    ):
-        self.model_name = model_name
-        self.model, self.tokenizer = load(self.model_name)
-        self.gen_kwargs = gen_kwargs
-
-        self.chat = Chat(chat_size)
-        if init_chat_role:
-            if not init_chat_prompt:
-                raise ValueError(
-                    "An initial prompt needs to be specified when setting init_chat_role."
-                )
-            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
-        self.user_role = user_role
-
-        self.warmup()
-
-    def warmup(self):
-        logger.info(f"Warming up {self.__class__.__name__}")
-
-        dummy_input_text = "Write me a poem about Machine Learning."
-        dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
-
-        n_steps = 2
-
-        for _ in range(n_steps):
-            prompt = self.tokenizer.apply_chat_template(dummy_chat, tokenize=False)
-            generate(
-                self.model,
-                self.tokenizer,
-                prompt=prompt,
-                max_tokens=self.gen_kwargs["max_new_tokens"],
-                verbose=False,
-            )
-
-    def process(self, prompt):
-        logger.debug("inferring language model...")
-
-        self.chat.append({"role": self.user_role, "content": prompt})
-
-        # Remove system messages if using a Gemma model
-        if "gemma" in self.model_name.lower():
-            chat_messages = [
-                msg for msg in self.chat.to_list() if msg["role"] != "system"
-            ]
-        else:
-            chat_messages = self.chat.to_list()
-
-        prompt = self.tokenizer.apply_chat_template(
-            chat_messages, tokenize=False, add_generation_prompt=True
-        )
-        output = ""
-        curr_output = ""
-        for t in stream_generate(
-            self.model,
-            self.tokenizer,
-            prompt,
-            max_tokens=self.gen_kwargs["max_new_tokens"],
-        ):
-            output += t
-            curr_output += t
-            if curr_output.endswith((".", "?", "!", "<|end|>")):
-                yield curr_output.replace("<|end|>", "")
-                curr_output = ""
-        generated_text = output.replace("<|end|>", "")
-        torch.mps.empty_cache()
-
-        self.chat.append({"role": "assistant", "content": generated_text})
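
For reference, the streaming pattern the deleted handler implemented can be reproduced with mlx_lm directly. The following is a minimal sketch, not part of the commit: it assumes mlx-lm is installed on Apple Silicon, reuses the model name and sentence-boundary chunking from the deleted file, and substitutes an illustrative max_tokens=128 for the handler's gen_kwargs["max_new_tokens"]. It also assumes the string-yielding stream_generate API the deleted code relied on; newer mlx_lm releases yield response objects with a .text field instead.

# Standalone sketch of the deleted handler's streaming loop (assumptions above).
from mlx_lm import load, stream_generate

# Load the same default checkpoint the deleted handler used.
model, tokenizer = load("microsoft/Phi-3-mini-4k-instruct")

chat = [{"role": "user", "content": "Write me a poem about Machine Learning."}]
prompt = tokenizer.apply_chat_template(
    chat, tokenize=False, add_generation_prompt=True
)

# Accumulate streamed text and flush a chunk at each sentence boundary,
# mirroring the deleted process() loop; 128 is an assumed token budget.
curr_output = ""
for t in stream_generate(model, tokenizer, prompt, max_tokens=128):
    curr_output += t
    if curr_output.endswith((".", "?", "!", "<|end|>")):
        print(curr_output.replace("<|end|>", ""))
        curr_output = ""

Chunking on sentence boundaries rather than per token is what let downstream consumers (such as a TTS stage) start speaking before generation finished.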