LennardZuendorf committed on
Commit
326ad4b
1 Parent(s): 2b7cdd8

fix: bugfix mistral generate method

Browse files
Files changed (1) hide show
  1. model/mistral.py +7 -2
model/mistral.py CHANGED
@@ -7,6 +7,7 @@ import gradio as gr
7
 
8
  # internal imports
9
  from utils import modelling as mdl
 
10
 
11
  # global model and tokenizer instance (created on inital build)
12
  device = mdl.get_device()
@@ -91,6 +92,9 @@ def format_answer(answer: str):
91
  # empty answer string
92
  formatted_answer = ""
93
 
 
 
 
94
  # extracting text after INST tokens
95
  parts = answer.split("[/INST]")
96
  if len(parts) >= 3:
@@ -106,10 +110,11 @@ def format_answer(answer: str):
106
  def respond(prompt: str):
107
 
108
  # tokenizing inputs and configuring model
109
- input_ids = TOKENIZER(f"{prompt}", return_tensors="pt")["input_ids"]
110
 
111
  # generating text with tokenized input, returning output
112
- output_ids = MODEL.generate(input_ids, max_new_tokens=50, generation_config=CONFIG)
113
  output_text = TOKENIZER.batch_decode(output_ids)
 
114
 
115
  return format_answer(output_text)
 
7
 
8
  # internal imports
9
  from utils import modelling as mdl
10
+ from utils import formatting as fmt
11
 
12
  # global model and tokenizer instance (created on inital build)
13
  device = mdl.get_device()
 
92
  # empty answer string
93
  formatted_answer = ""
94
 
95
+ if type(answer) == list:
96
+ answer = fmt.format_output_text
97
+
98
  # extracting text after INST tokens
99
  parts = answer.split("[/INST]")
100
  if len(parts) >= 3:
 
110
  def respond(prompt: str):
111
 
112
  # tokenizing inputs and configuring model
113
+ input_ids = TOKENIZER(f"{prompt}", return_tensors="pt")["input_ids"].to(device)
114
 
115
  # generating text with tokenized input, returning output
116
+ output_ids = MODEL.generate(input_ids, generation_config=CONFIG)
117
  output_text = TOKENIZER.batch_decode(output_ids)
118
+ output_text.fo
119
 
120
  return format_answer(output_text)