sohojoe commited on
Commit
552fb1e
1 Parent(s): 92b10f9

basic openai support

Browse files
Files changed (3) hide show
  1. chat_service.py +26 -11
  2. environment.yml +1 -0
  3. requirements.txt +2 -1
chat_service.py CHANGED
@@ -1,35 +1,49 @@
1
  import os
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
4
 
5
  # from huggingface_hub.inference_api import InferenceApi
6
 
7
  class ChatService:
8
- def __init__(self, api="huggingface", repo_id = "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5"):
 
9
  self._api = api
10
  self._device = "cuda:0" if torch.cuda.is_available() else "cpu"
11
 
12
- if self._api=="huggingface":
13
- self._tokenizer = AutoTokenizer.from_pretrained(repo_id)
14
- self._model = AutoModelForCausalLM.from_pretrained(repo_id,torch_dtype=torch.float16)
15
- # self._model = AutoModelForCausalLM.from_pretrained(repo_id).half()
 
 
 
 
 
 
16
  self._model.eval().to(self._device)
17
  else:
18
  raise Exception(f"Unknown API: {self._api}")
19
 
20
- self._system_prompt = "Below are a series of dialogues between various people and an AI assistant. The AI tries to be helpful, polite, honest, sophisticated, emotionally aware, and humble-but-knowledgeable. The assistant is happy to help with almost anything, and will do its best to understand exactly what is needed. It also tries to avoid giving false or misleading information, and it caveats when it isn't entirely sure about the right answer. That said, the assistant is practical and really does its best, and doesn't let caution get too much in the way of being useful.\n-----\n"
21
- self._user_name = "<|prompter|>"
22
- self._agent_name = "<|assistant|>"
23
  self.reset()
24
 
25
  def reset(self):
26
  self._user_history = []
27
  self._agent_history = []
28
- self._full_history = self._user_history if self._user_history else ""
 
 
 
29
 
30
 
31
  def _chat(self, prompt):
32
- if self._api=="huggingface":
 
 
 
 
 
 
33
  tokens = self._tokenizer.encode(prompt, return_tensors="pt", padding=True)
34
  tokens = tokens.to(self._device)
35
  outputs = self._model.generate(
@@ -42,7 +56,6 @@ class ChatService:
42
  pad_token_id=self._tokenizer.eos_token_id,
43
  )
44
  agent_response = self._tokenizer.decode(outputs[0], truncate_before_pattern=[r"\n\n^#", "^'''", "\n\n\n"])
45
-
46
  else:
47
  raise Exception(f"API not implemented: {self._api}")
48
  return agent_response
@@ -52,8 +65,10 @@ class ChatService:
52
  self._full_history += f"{self._user_name}: {prompt}\n"
53
  else:
54
  self._full_history += f"{prompt}\n"
 
55
  self._user_history.append(prompt)
56
  agent_response = self._chat(self._full_history)
 
57
  if self._agent_name:
58
  self._full_history += f"{self._agent_name}: {agent_response}\n"
59
  else:
 
1
  import os
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ import openai
5
 
6
  # from huggingface_hub.inference_api import InferenceApi
7
 
8
  class ChatService:
9
+ def __init__(self, api="openai", model_id = "gpt-3.5-turbo"):
10
+ # def __init__(self, api="huggingface", model_id = "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5"):
11
  self._api = api
12
  self._device = "cuda:0" if torch.cuda.is_available() else "cpu"
13
 
14
+ if self._api=="openai":
15
+ openai.api_key = os.getenv("OPENAI_API_KEY")
16
+ self._model_id = model_id
17
+ elif self._api=="huggingface":
18
+ self._system_prompt = "Below are a series of dialogues between various people and an AI assistant. The AI tries to be helpful, polite, honest, sophisticated, emotionally aware, and humble-but-knowledgeable. The assistant is happy to help with almost anything, and will do its best to understand exactly what is needed. It also tries to avoid giving false or misleading information, and it caveats when it isn't entirely sure about the right answer. That said, the assistant is practical and really does its best, and doesn't let caution get too much in the way of being useful.\n-----\n"
19
+ self._user_name = "<|prompter|>"
20
+ self._agent_name = "<|assistant|>"
21
+ self._tokenizer = AutoTokenizer.from_pretrained(model_id)
22
+ self._model = AutoModelForCausalLM.from_pretrained(model_id,torch_dtype=torch.float16)
23
+ # self._model = AutoModelForCausalLM.from_pretrained(model_id).half()
24
  self._model.eval().to(self._device)
25
  else:
26
  raise Exception(f"Unknown API: {self._api}")
27
 
 
 
 
28
  self.reset()
29
 
30
  def reset(self):
31
  self._user_history = []
32
  self._agent_history = []
33
+ self._full_history = self._system_prompt if self._system_prompt else ""
34
+ self._messages = []
35
+ if self._system_prompt:
36
+ self._messages.append({"role": "system", "content": self._system_prompt})
37
 
38
 
39
  def _chat(self, prompt):
40
+ if self._api=="openai":
41
+ response = openai.ChatCompletion.create(
42
+ model=self._model_id,
43
+ messages=self._messages,
44
+ )
45
+ agent_response = response['choices'][0]['message']['content']
46
+ elif self._api=="huggingface":
47
  tokens = self._tokenizer.encode(prompt, return_tensors="pt", padding=True)
48
  tokens = tokens.to(self._device)
49
  outputs = self._model.generate(
 
56
  pad_token_id=self._tokenizer.eos_token_id,
57
  )
58
  agent_response = self._tokenizer.decode(outputs[0], truncate_before_pattern=[r"\n\n^#", "^'''", "\n\n\n"])
 
59
  else:
60
  raise Exception(f"API not implemented: {self._api}")
61
  return agent_response
 
65
  self._full_history += f"{self._user_name}: {prompt}\n"
66
  else:
67
  self._full_history += f"{prompt}\n"
68
+ self._messages.append({"role": "user", "content": prompt})
69
  self._user_history.append(prompt)
70
  agent_response = self._chat(self._full_history)
71
+ self._messages.append({"role": "assistant", "content": agent_response})
72
  if self._agent_name:
73
  self._full_history += f"{self._agent_name}: {agent_response}\n"
74
  else:
environment.yml CHANGED
@@ -19,3 +19,4 @@ dependencies:
19
  - open_clip_torch
20
  - vosk
21
  - transformers
 
 
19
  - open_clip_torch
20
  - vosk
21
  - transformers
22
+ - openai
requirements.txt CHANGED
@@ -14,4 +14,5 @@ pydub
14
  torch
15
  numpy
16
  open_clip_torch
17
- transformers
 
 
14
  torch
15
  numpy
16
  open_clip_torch
17
+ transformers
18
+ openai