Ashishkr committed on
Commit
91e30ca
1 Parent(s): cb1f85a

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +15 -19
model.py CHANGED
@@ -68,30 +68,26 @@ def run(message: str,
68
  max_new_tokens: int = 256,
69
  temperature: float = 0.8,
70
  top_p: float = 0.95,
71
- top_k: int = 50) -> Iterator[str]:
72
  prompt = get_prompt(message, chat_history, system_prompt)
73
  inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to(device)
74
 
75
- streamer = TextIteratorStreamer(tokenizer,
76
- timeout=10.,
77
- skip_prompt=True,
78
- skip_special_tokens=True)
79
- generate_kwargs = dict(
80
- inputs,
81
- streamer=streamer,
82
- max_new_tokens=max_new_tokens,
83
  do_sample=True,
84
  top_p=top_p,
85
  top_k=top_k,
86
  temperature=temperature,
87
- num_beams=1,
88
  )
89
- t = Thread(target=model.generate, kwargs=generate_kwargs)
90
- t.start()
91
-
92
- outputs = []
93
- for text in streamer:
94
- outputs.append(text)
95
- if "instruction:" in text:
96
- break
97
- yield ''.join(outputs)
 
68
  max_new_tokens: int = 256,
69
  temperature: float = 0.8,
70
  top_p: float = 0.95,
71
+ top_k: int = 50) -> str:
72
  prompt = get_prompt(message, chat_history, system_prompt)
73
  inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to(device)
74
 
75
+ # Generate tokens using the model
76
+ output = model.generate(
77
+ input_ids=inputs['input_ids'],
78
+ attention_mask=inputs['attention_mask'],
79
+ max_length=max_new_tokens + inputs['input_ids'].shape[-1],
 
 
 
80
  do_sample=True,
81
  top_p=top_p,
82
  top_k=top_k,
83
  temperature=temperature,
84
+ num_beams=1
85
  )
86
+
87
+ # Decode the output tokens back to a string
88
+ output_text = tokenizer.decode(output[0], skip_special_tokens=True)
89
+
90
+ # Remove everything including and after "instruct: "
91
+ output_text = output_text.split("instruct: ")[0]
92
+
93
+ return output_text