vilarin commited on
Commit
4ed884e
1 Parent(s): 10efa15

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -6
app.py CHANGED
@@ -43,14 +43,23 @@ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
43
  model = model.eval()
44
 
45
  @spaces.GPU()
46
- def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
47
-
 
 
 
 
 
 
 
 
 
48
  for resp, history in model.stream_chat(
49
  tokenizer,
50
  query = message,
51
  history = history,
52
  max_new_tokens = max_new_tokens,
53
- do_sample = True if temperature == 0 else False,
54
  top_p = top_p,
55
  top_k = top_k,
56
  temperature = temperature,
@@ -80,7 +89,7 @@ with gr.Blocks(css=CSS, theme="soft") as demo:
80
  ),
81
  gr.Slider(
82
  minimum=128,
83
- maximum=2048,
84
  step=1,
85
  value=1024,
86
  label="Max New Tokens",
@@ -90,7 +99,7 @@ with gr.Blocks(css=CSS, theme="soft") as demo:
90
  minimum=0.0,
91
  maximum=1.0,
92
  step=0.1,
93
- value=0.8,
94
  label="top_p",
95
  render=False,
96
  ),
@@ -106,7 +115,7 @@ with gr.Blocks(css=CSS, theme="soft") as demo:
106
  minimum=0.0,
107
  maximum=2.0,
108
  step=0.1,
109
- value=1.0,
110
  label="Repetition penalty",
111
  render=False,
112
  ),
 
43
  model = model.eval()
44
 
45
  @spaces.GPU()
46
+ def stream_chat(
47
+ message: str,
48
+ history: list,
49
+ temperature: float = 0.8,
50
+ max_new_tokens: int = 1024,
51
+ top_p: float = 1.0,
52
+ top_k: int = 20,
53
+ penalty: float = 1.2
54
+ ):
55
+ print(f'message: {message}')
56
+ print(f'history: {history}')
57
  for resp, history in model.stream_chat(
58
  tokenizer,
59
  query = message,
60
  history = history,
61
  max_new_tokens = max_new_tokens,
62
+ do_sample = False if temperature == 0 else True,
63
  top_p = top_p,
64
  top_k = top_k,
65
  temperature = temperature,
 
89
  ),
90
  gr.Slider(
91
  minimum=128,
92
+ maximum=8192,
93
  step=1,
94
  value=1024,
95
  label="Max New Tokens",
 
99
  minimum=0.0,
100
  maximum=1.0,
101
  step=0.1,
102
+ value=1.0,
103
  label="top_p",
104
  render=False,
105
  ),
 
115
  minimum=0.0,
116
  maximum=2.0,
117
  step=0.1,
118
+ value=1.2,
119
  label="Repetition penalty",
120
  render=False,
121
  ),