pszemraj committed on
Commit
b67934e
1 Parent(s): b4c0306

🎨 format


Signed-off-by: peter szemraj <peterszemraj@gmail.com>

Files changed (1)
  1. converse.py +12 -11
converse.py CHANGED
@@ -5,7 +5,10 @@
 """
 
 import logging
-logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
 import pprint as pp
 import time
 
@@ -13,6 +16,7 @@ from grammar_improve import remove_trailing_punctuation
 
 from constrained_generation import constrained_generation
 
+
 def discussion(
     prompt_text: str,
     speaker: str,
@@ -65,10 +69,9 @@ def discussion(
     if verbose:
         print("overall prompt:\n")
         pp.pprint(this_prompt, indent=4)
-    # call the model
-    print("\n... generating...")
+
     if constrained_beam_search:
-        logging.info("using constrained beam search")
+        logging.info("generating using constrained beam search ...")
         response = constrained_generation(
             prompt=this_prompt,
             pipeline=pipeline,
@@ -85,15 +88,13 @@
 
         bot_dialogue = consolidate_texts(
             name_resp=responder,
-            model_resp=response.split(
-                "\n"
-            ),
+            model_resp=response.split("\n"),
             name_spk=speaker,
             verbose=verbose,
             print_debug=True,
         )
     else:
-        logging.info("using sampling")
+        logging.info("generating using sampling ...")
         bot_dialogue = gen_response(
             this_prompt,
             pipeline,
@@ -140,15 +141,15 @@ def gen_response(
     speaker: str,
     responder: str,
     timeout=45,
-    min_length=4,
+    min_length=12,
     max_length=48,
     top_p=0.95,
     top_k=20,
     temperature=0.5,
     full_text=False,
     num_return_sequences=1,
-    length_penalty:float=0.8,
-    repetition_penalty:float=3.5,
+    length_penalty: float = 0.8,
+    repetition_penalty: float = 3.5,
     no_repeat_ngram_size=2,
     device=-1,
     verbose=False,
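Besides the formatting pass, the diff bumps the default min_length from 4 to 12 and normalizes the type hints on length_penalty and repetition_penalty in gen_response(). As a rough sketch only, the snippet below shows how those sampling defaults would look when passed directly to a Hugging Face text-generation pipeline; the model name, the prompt, and the assumption that gen_response() forwards these kwargs to the pipeline are placeholders for illustration, not part of this commit.

import logging

from transformers import pipeline

# Same logging format the commit reformats into a multi-line basicConfig call.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

# Hypothetical model and prompt -- not taken from the diff.
generator = pipeline("text-generation", model="gpt2", device=-1)
prompt = "Speaker: how was your weekend?\nResponder:"

logging.info("generating using sampling ...")
# Sampling defaults mirroring the updated gen_response() signature; note that
# min_length/max_length count prompt tokens too, and length_penalty is omitted
# here because it only affects beam search.
result = generator(
    prompt,
    do_sample=True,
    min_length=12,
    max_length=48,
    top_p=0.95,
    top_k=20,
    temperature=0.5,
    repetition_penalty=3.5,
    no_repeat_ngram_size=2,
    num_return_sequences=1,
    return_full_text=False,
)
print(result[0]["generated_text"])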