Julian-Hans committed on
Commit 714b9dd
1 Parent(s): 52d92f0

changed inference provider for phi3 to llama_cpp

Files changed (3)
  1. .gitignore +4 -1
  2. phi3_mini_4k_instruct.py +15 -4
  3. requirements.txt +1 -0
.gitignore CHANGED
@@ -1,2 +1,5 @@
 /__pycache__
-*.wav
+*.wav
+.pytest_cache
+/audio_data
+.cache
phi3_mini_4k_instruct.py CHANGED
@@ -1,9 +1,10 @@
 # external imports
 from transformers import pipeline
 from huggingface_hub import InferenceClient
-
+import torch
 # local imports
 import config
+from llama_cpp import Llama
 
 
 class Phi3_Mini_4k_Instruct:
@@ -12,12 +13,22 @@ class Phi3_Mini_4k_Instruct:
 
     def generate_text(self, messages, use_local_llm):
         if use_local_llm:
-            return self.generate_text_local_pipeline(messages)
+            return self.generate_text_llama_cpp(messages)
         else:
             return self.generate_text_api(messages)
-
+
+    def generate_text_llama_cpp(self, messages):
+        model = Llama.from_pretrained(
+            repo_id="microsoft/Phi-3-mini-4k-instruct-gguf",
+            filename="Phi-3-mini-4k-instruct-q4.gguf"
+        )
+        response = model.create_chat_completion(messages)
+        generated_message = response['choices'][0]['message']['content']
+
+        return generated_message
+
     def generate_text_local_pipeline(self, messages):
-        self.local_pipeline = pipeline("text-generation", model=config.LLM_MODEL, trust_remote_code=True)
+        self.local_pipeline = pipeline("text-generation", model=config.LLM_MODEL, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto")
         self.local_pipeline.model.config.max_length = config.LLM_MAX_LENGTH
         self.local_pipeline.model.config.max_new_tokens = config.LLM_MAX_NEW_TOKENS
         self.local_pipeline.model.config.temperature = config.LLM_TEMPERATURE
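
With this commit, calling generate_text with use_local_llm=True routes through llama_cpp's chat-completion API instead of the transformers pipeline. A minimal usage sketch of the new path, assuming Phi3_Mini_4k_Instruct takes no constructor arguments (its __init__ is not part of this diff) and that messages use the OpenAI-style chat format expected by create_chat_completion:

# usage sketch (hypothetical caller, not part of the commit)
from phi3_mini_4k_instruct import Phi3_Mini_4k_Instruct

llm = Phi3_Mini_4k_Instruct()  # assumes a no-argument constructor
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize llama.cpp in one sentence."},
]
# True selects generate_text_llama_cpp; False keeps the InferenceClient API path
print(llm.generate_text(messages, use_local_llm=True))

Note that generate_text_llama_cpp re-loads the GGUF model on every call; caching the Llama instance on self would avoid paying that startup cost per request.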
requirements.txt CHANGED
@@ -23,6 +23,7 @@ importlib_resources==6.4.5
 iniconfig==2.0.0
 Jinja2==3.1.4
 kiwisolver==1.4.7
+llama_cpp_python==0.3.1
 markdown-it-py==3.0.0
 MarkupSafe==2.1.5
 matplotlib==3.9.2
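
The newly pinned dependency can be smoke-tested after installing the updated requirements. A quick check, assuming the package exposes a __version__ attribute as recent llama-cpp-python releases do:

# verify the pinned llama_cpp_python build is importable
import llama_cpp
print(llama_cpp.__version__)  # expected to match requirements.txt: 0.3.1

from llama_cpp import Llama  # the class used by generate_text_llama_cpp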