Kims12 commited on
Commit
7a6d838
1 Parent(s): 9a5baed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -6
app.py CHANGED
@@ -3,20 +3,43 @@ from huggingface_hub import InferenceClient
3
  import os
4
  import openai
5
 
 
6
  MODELS = {
7
  "Zephyr 7B Beta": "HuggingFaceH4/zephyr-7b-beta",
8
  "DeepSeek Coder V2": "deepseek-ai/DeepSeek-Coder-V2-Instruct",
9
  "Meta Llama 3.1 8B": "meta-llama/Meta-Llama-3.1-8B-Instruct",
10
  "Mixtral 8x7B": "mistralai/Mixtral-8x7B-Instruct-v0.1",
11
  "Cohere Command R+": "CohereForAI/c4ai-command-r-plus",
 
12
  }
13
 
 
 
 
14
  def get_client(model_name):
15
- model_id = MODELS[model_name]
16
- hf_token = os.getenv("HF_TOKEN")
17
- if not hf_token:
18
- raise ValueError("HF_TOKEN environment variable is required")
19
- return InferenceClient(model_id, token=hf_token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  def respond(
22
  message,
@@ -29,6 +52,12 @@ def respond(
29
  ):
30
  try:
31
  client = get_client(model_name)
 
 
 
 
 
 
32
  except ValueError as e:
33
  chat_history.append((message, str(e)))
34
  return chat_history
@@ -112,4 +141,4 @@ with gr.Blocks() as demo:
112
  clear_button.click(clear_conversation, outputs=chatbot, queue=False)
113
 
114
  if __name__ == "__main__":
115
- demo.launch()
 
3
  import os
4
  import openai
5
 
6
# Model registry: UI display name -> model identifier.
# NOTE(review): "chatgpt-4o-mini" is routed to the OpenAI API by get_client
# (which returns None for it) rather than to a Hugging Face InferenceClient.
MODELS = {
    "Zephyr 7B Beta": "HuggingFaceH4/zephyr-7b-beta",
    "DeepSeek Coder V2": "deepseek-ai/DeepSeek-Coder-V2-Instruct",
    "Meta Llama 3.1 8B": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "Mixtral 8x7B": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "Cohere Command R+": "CohereForAI/c4ai-command-r-plus",
    "chatgpt-4o-mini": "chatgpt-4o-mini",  # OpenAI model (not a HF repo id)
}

# Configure the OpenAI client from the environment.
# NOTE(review): if OPENAI_API_KEY is unset this stays None and OpenAI calls
# will fail at request time, not at import time.
openai.api_key = os.getenv("OPENAI_API_KEY")
18
+
19
def get_client(model_name):
    """Resolve *model_name* to an inference client.

    Returns None for OpenAI-backed models (handled via the ``openai``
    module by the caller), otherwise a Hugging Face ``InferenceClient``.

    Raises:
        ValueError: if the model is unknown or HF_TOKEN is not set.
    """
    # Guard clause: reject names that are not in the registry.
    if model_name not in MODELS:
        raise ValueError(f"Model {model_name} is not supported")

    model_id = MODELS[model_name]

    # OpenAI models have no HF client; caller detects the None sentinel.
    if "chatgpt" in model_name:
        return None

    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN environment variable is required")
    return InferenceClient(model_id, token=hf_token)
30
+
31
def call_api(content, system_message, max_tokens, temperature, top_p):
    """Send a single-turn chat completion to OpenAI and return the reply text.

    NOTE(review): this uses the legacy ``openai.ChatCompletion`` interface
    (openai < 1.0); it will break under openai >= 1.0 — confirm the pinned
    package version.
    """
    chat_messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": content},
    ]
    response = openai.ChatCompletion.create(
        model="gpt-4o-mini",
        messages=chat_messages,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    # Legacy OpenAIObject responses support both attribute and dict-style
    # access; dict-style is used uniformly here.
    return response["choices"][0]["message"]["content"]
43
 
44
  def respond(
45
  message,
 
52
  ):
53
  try:
54
  client = get_client(model_name)
55
+ if client is None and "chatgpt" in model_name: # OpenAI 모델의 경우
56
+ assistant_message = call_api(message, system_message, max_tokens, temperature, top_p)
57
+ chat_history.append((message, assistant_message))
58
+ yield chat_history
59
+ return
60
+
61
  except ValueError as e:
62
  chat_history.append((message, str(e)))
63
  return chat_history
 
141
    # Reset the chat display; queue=False so the clear is applied immediately
    # instead of going through the Gradio event queue.
    clear_button.click(clear_conversation, outputs=chatbot, queue=False)

if __name__ == "__main__":
    demo.launch()