Tonic committed on
Commit
fd5c68e
1 Parent(s): e32564e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -5
app.py CHANGED
@@ -53,16 +53,13 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
53
 
54
  model = AutoModelForCausalLM.from_pretrained(model_id , torch_dtype=torch.float16 , device_map= "auto" )
55
 
56
- class ChatBot:
57
- def __init__(self):
58
- self.history = []
59
 
60
  class ChatBot:
61
  def __init__(self):
62
  # Initialize the ChatBot class with an empty history
63
  self.history = []
64
 
65
- def predict(self, user_input, system_prompt="You are an expert medical analyst:"):
66
  # Combine the user's input with the system prompt
67
  formatted_input = f"<s> [INST] {example_instruction} [/INST] {example_answer}</s> [INST] {system_prompt} [/INST]"
68
 
@@ -70,7 +67,7 @@ class ChatBot:
70
  user_input_ids = tokenizer.encode(formatted_input, return_tensors="pt")
71
 
72
  # Generate a response using the PEFT model
73
- response = peft_model.generate(input_ids=user_input_ids, max_length=512, pad_token_id=tokenizer.eos_token_id)
74
 
75
  # Decode the generated response to text
76
  response_text = tokenizer.decode(response[0], skip_special_tokens=True)
 
53
 
54
  model = AutoModelForCausalLM.from_pretrained(model_id , torch_dtype=torch.float16 , device_map= "auto" )
55
 
 
 
 
56
 
57
  class ChatBot:
58
  def __init__(self):
59
  # Initialize the ChatBot class with an empty history
60
  self.history = []
61
 
62
+ def predict(self, user_input, system_prompt="You are an expert medical analyst:" , example_instruction="produce a json", example_answer = "please dont make small talk "):
63
  # Combine the user's input with the system prompt
64
  formatted_input = f"<s> [INST] {example_instruction} [/INST] {example_answer}</s> [INST] {system_prompt} [/INST]"
65
 
 
67
  user_input_ids = tokenizer.encode(formatted_input, return_tensors="pt")
68
 
69
  # Generate a response using the PEFT model
70
+ response = model.generate(input_ids=user_input_ids, max_length=512, pad_token_id=tokenizer.eos_token_id)
71
 
72
  # Decode the generated response to text
73
  response_text = tokenizer.decode(response[0], skip_special_tokens=True)