NCTCMumbai committed
Commit 9ef7776
1 Parent(s): 1705054

Update backend/query_llm.py

Files changed (1)
backend/query_llm.py +12 -8
backend/query_llm.py CHANGED
@@ -9,21 +9,25 @@ from typing import Any, Dict, Generator, List
 from huggingface_hub import InferenceClient
 from transformers import AutoTokenizer
 
-tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
-
-temperature = 0.9
-top_p = 0.6
+#tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
+tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
+temperature = 0.4
+#top_p = 0.6
 repetition_penalty = 1.2
 
+
 OPENAI_KEY = getenv("OPENAI_API_KEY")
 HF_TOKEN = getenv("HUGGING_FACE_HUB_TOKEN")
 
+# hf_client = InferenceClient(
+#     "mistralai/Mistral-7B-Instruct-v0.1",
+#     token=HF_TOKEN
+# )
+
 hf_client = InferenceClient(
-    "mistralai/Mistral-7B-Instruct-v0.1",
+    "mistralai/Mixtral-8x7B-Instruct-v0.1",
     token=HF_TOKEN
 )
-
-
 def format_prompt(message: str, api_kind: str):
     """
     Formats the given message using a chat template.
@@ -46,7 +50,7 @@ def format_prompt(message: str, api_kind: str):
     raise ValueError("API is not supported")
 
 
-def generate_hf(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 256,
+def generate_hf(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 4000,
                 top_p: float = 0.95, repetition_penalty: float = 1.0) -> Generator[str, None, str]:
     """
     Generate a sequence of tokens based on a given prompt and history using Mistral client.
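
The hunks above swap Mistral-7B for Mixtral-8x7B in both the tokenizer and the inference client, lower the module-level sampling temperature from 0.9 to 0.4, and raise generate_hf's default max_new_tokens from 256 to 4000. As a minimal sketch (not part of this commit), the snippet below shows how these pieces are typically wired together: format_prompt's body is not visible in the diff, so the apply_chat_template call is an assumption based on the AutoTokenizer import, while text_generation is the standard InferenceClient streaming call.

from os import getenv

from huggingface_hub import InferenceClient
from transformers import AutoTokenizer

# Same model id the commit switches to.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
hf_client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1",
                            token=getenv("HUGGING_FACE_HUB_TOKEN"))

# Assumed behaviour of format_prompt for the HF API kind: render one user
# turn with the model's chat template.
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Summarise the retrieved context."}],
    tokenize=False,
    add_generation_prompt=True,
)

# Stream tokens with generate_hf's post-commit defaults; note the new
# max_new_tokens ceiling of 4000.
for token in hf_client.text_generation(
    prompt,
    max_new_tokens=4000,
    temperature=0.9,
    top_p=0.95,
    repetition_penalty=1.0,
    stream=True,
):
    print(token, end="", flush=True)

One detail worth noting: the module-level temperature = 0.4 only takes effect where callers pass it through; generate_hf's own signature default stays at 0.9 in this commit.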