z00mP committed
Commit
b3c801a
1 Parent(s): 0b76a3e

add different llms

Files changed (2)
  1. app.py +7 -7
  2. backend/query_llm.py +13 -5
app.py CHANGED
@@ -63,25 +63,25 @@ def bot(history, api_kind, chunk_table, embedding_model, llm_model, cross_encode
     prompt_html = template_html.render(documents=documents, query=query)
 
     if llm_model == "mistralai/Mistral-7B-Instruct-v0.2":
-        pass
+        generate_fn = generate_hf
     if llm_model == "mistralai/Mistral-7B-v0.1":
-        pass
+        generate_fn = generate_hf
     if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
-        pass
+        generate_fn = generate_hf
     if llm_model == "gpt-3.5-turbo":
-        pass
+        generate_fn = generate_openai
    if llm_model == "gpt-4-turbo-preview":
-        pass
+        generate_fn = generate_openai
 
     #if api_kind == "HuggingFace":
     #    generate_fn = generate_hf
     #elif api_kind == "OpenAI":
     #    generate_fn = generate_openai
     #else:
-    #    raise gr.Error(f"API {api_kind} is not supported")
+    raise gr.Error(f"API {api_kind} is not supported")
 
     history[-1][1] = ""
-    for character in generate_fn(prompt, history[:-1]):
+    for character in generate_fn(prompt, history[:-1], llm_model):
         history[-1][1] = character
         yield history, prompt_html
 
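The chained `if` statements above hard-wire each model name to a backend: the Mistral/Mixtral models go through `generate_hf`, the GPT models through `generate_openai`, and `llm_model` is now passed into the generator call. A minimal table-driven sketch of the same dispatch (the `MODEL_BACKENDS` name is illustrative, not part of this commit):

# Hypothetical equivalent of the chained ifs in bot(); not code from this commit.
MODEL_BACKENDS = {
    "mistralai/Mistral-7B-Instruct-v0.2": generate_hf,
    "mistralai/Mistral-7B-v0.1": generate_hf,
    "mistralai/Mixtral-8x7B-Instruct-v0.1": generate_hf,
    "gpt-3.5-turbo": generate_openai,
    "gpt-4-turbo-preview": generate_openai,
}

generate_fn = MODEL_BACKENDS.get(llm_model)
if generate_fn is None:
    raise gr.Error(f"Model {llm_model} is not supported")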
backend/query_llm.py CHANGED
@@ -34,7 +34,7 @@ OAI_GENERATE_KWARGS = {
34
  }
35
 
36
 
37
- def format_prompt(message: str, api_kind: str):
38
  """
39
  Formats the given message using a chat template.
40
 
@@ -51,12 +51,13 @@ def format_prompt(message: str, api_kind: str):
51
  if api_kind == "openai":
52
  return messages
53
  elif api_kind == "hf":
 
54
  return TOKENIZER.apply_chat_template(messages, tokenize=False)
55
  elif api_kind:
56
  raise ValueError("API is not supported")
57
 
58
 
59
- def generate_hf(prompt: str, history: str) -> Generator[str, None, str]:
60
  """
61
  Generate a sequence of tokens based on a given prompt and history using Mistral client.
62
 
@@ -67,8 +68,14 @@ def generate_hf(prompt: str, history: str) -> Generator[str, None, str]:
67
  Generator[str, None, str]: A generator yielding chunks of generated text.
68
  Returns a final string if an error occurs.
69
  """
 
70
 
71
- formatted_prompt = format_prompt(prompt, "hf")
 
 
 
 
 
72
  formatted_prompt = formatted_prompt.encode("utf-8").decode("utf-8")
73
 
74
  try:
@@ -93,7 +100,7 @@ def generate_hf(prompt: str, history: str) -> Generator[str, None, str]:
93
  raise gr.Error(f"Unhandled Exception: {str(e)}")
94
 
95
 
96
- def generate_openai(prompt: str, history: str) -> Generator[str, None, str]:
97
  """
98
  Generate a sequence of tokens based on a given prompt and history using Mistral client.
99
 
@@ -108,7 +115,8 @@ def generate_openai(prompt: str, history: str) -> Generator[str, None, str]:
108
 
109
  try:
110
  stream = OAI_CLIENT.chat.completions.create(
111
- model=os.getenv("OPENAI_MODEL"),
 
112
  messages=formatted_prompt,
113
  **OAI_GENERATE_KWARGS,
114
  stream=True
 
34
  }
35
 
36
 
37
+ def format_prompt(message: str, api_kind: str, tokenizer_name = None):
38
  """
39
  Formats the given message using a chat template.
40
 
 
51
  if api_kind == "openai":
52
  return messages
53
  elif api_kind == "hf":
54
+ TOKENIZER = AutoTokenizer.from_pretrained(tokenizer_name)
55
  return TOKENIZER.apply_chat_template(messages, tokenize=False)
56
  elif api_kind:
57
  raise ValueError("API is not supported")
58
 
59
 
60
+ def generate_hf(prompt: str, history: str, hf_model_name: str) -> Generator[str, None, str]:
61
  """
62
  Generate a sequence of tokens based on a given prompt and history using Mistral client.
63
 
 
68
  Generator[str, None, str]: A generator yielding chunks of generated text.
69
  Returns a final string if an error occurs.
70
  """
71
+
72
 
73
+ HF_CLIENT = InferenceClient(
74
+ hf_model_name,
75
+ token=os.getenv("HF_TOKEN")
76
+ )
77
+
78
+ formatted_prompt = format_prompt(prompt, "hf", hf_model_name)
79
  formatted_prompt = formatted_prompt.encode("utf-8").decode("utf-8")
80
 
81
  try:
 
100
  raise gr.Error(f"Unhandled Exception: {str(e)}")
101
 
102
 
103
+ def generate_openai(prompt: str, history: str, model_name: str) -> Generator[str, None, str]:
104
  """
105
  Generate a sequence of tokens based on a given prompt and history using Mistral client.
106
 
 
115
 
116
  try:
117
  stream = OAI_CLIENT.chat.completions.create(
118
+ #model=os.getenv("OPENAI_MODEL"),
119
+ model = model_name,
120
  messages=formatted_prompt,
121
  **OAI_GENERATE_KWARGS,
122
  stream=True
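
For orientation, a self-contained sketch of the per-model HF path these signature changes enable: the InferenceClient and chat template are now resolved from the model name passed down from app.py rather than from module-level constants. The helper name hf_stream_demo and the max_new_tokens value below are illustrative assumptions, not code from the commit.

import os

from huggingface_hub import InferenceClient
from transformers import AutoTokenizer

def hf_stream_demo(prompt: str, model_name: str = "mistralai/Mistral-7B-Instruct-v0.2"):
    # Build the client per call, as the reworked generate_hf does.
    client = InferenceClient(model_name, token=os.getenv("HF_TOKEN"))
    # Apply the model's own chat template, mirroring format_prompt(message, "hf", tokenizer_name).
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    formatted = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}], tokenize=False
    )
    # Stream generated text chunk by chunk.
    for chunk in client.text_generation(formatted, max_new_tokens=256, stream=True):
        yield chunk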