Tonic commited on
Commit
9bc49ef
1 Parent(s): 5ab0bbc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -18,16 +18,17 @@ class OrcaChatBot:
18
  def __init__(self, model, tokenizer, system_message="You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior."):
19
  self.model = model
20
  self.tokenizer = tokenizer
21
- self.default_system_message = system_message
22
 
23
- def format_prompt(self, user_message, system_message):
24
- if system_message is None:
25
- system_message = self.default_system_message
 
26
  prompt = f"<|im_start|>assistant\n{self.system_message}<|im_end|>\n<|im_start|>\nuser\n{user_message}<|im_end|>\nassistant\n"
27
  return prompt
28
 
29
- def predict(self, user_message, system_message=None, temperature=0.4, max_new_tokens=70, top_p=0.99, repetition_penalty=1.9):
30
- prompt = self.format_prompt(user_message, system_message)
31
  inputs = self.tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
32
  input_ids = inputs["input_ids"].to(self.model.device)
33
 
@@ -44,7 +45,8 @@ class OrcaChatBot:
44
  return response
45
 
46
  def gradio_predict(user_message, system_message, max_new_tokens, temperature, top_p, repetition_penalty):
47
- response = Orca_bot.predict(user_message, system_message, temperature, max_new_tokens, top_p, repetition_penalty)
 
48
  return response
49
 
50
  Orca_bot = OrcaChatBot(model, tokenizer)
 
18
  def __init__(self, model, tokenizer, system_message="You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior."):
19
  self.model = model
20
  self.tokenizer = tokenizer
21
+ self.system_message = system_message
22
 
23
+ def set_system_message(self, new_system_message):
24
+ self.system_message = new_system_message
25
+
26
+ def format_prompt(self, user_message):
27
  prompt = f"<|im_start|>assistant\n{self.system_message}<|im_end|>\n<|im_start|>\nuser\n{user_message}<|im_end|>\nassistant\n"
28
  return prompt
29
 
30
+ def predict(self, user_message, temperature=0.4, max_new_tokens=70, top_p=0.99, repetition_penalty=1.9):
31
+ prompt = self.format_prompt(user_message)
32
  inputs = self.tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
33
  input_ids = inputs["input_ids"].to(self.model.device)
34
 
 
45
  return response
46
 
47
  def gradio_predict(user_message, system_message, max_new_tokens, temperature, top_p, repetition_penalty):
48
+ Orca_bot.set_system_message(system_message)
49
+ response = Orca_bot.predict(user_message, temperature, max_new_tokens, top_p, repetition_penalty)
50
  return response
51
 
52
  Orca_bot = OrcaChatBot(model, tokenizer)