eswardivi committed
Commit ea9c0d3
1 Parent(s): 8ea3940

Update app.py

Files changed (1): app.py (+11 -6)
app.py CHANGED
@@ -21,6 +21,10 @@ model = AutoModelForCausalLM.from_pretrained(
     "meta-llama/Meta-Llama-3-8B-Instruct", quantization_config=quantization_config, token=token
 )
 tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", token=token)
+terminators = [
+    tokenizer.eos_token_id,
+    tokenizer.convert_tokens_to_ids("<|eot_id|>")
+]
 
 if torch.cuda.is_available():
     device = torch.device("cuda")
@@ -34,7 +38,7 @@ else:
 
 
 @spaces.GPU(duration=150)
-def chat(message, history, temperature,do_sample, top_k, max_tokens):
+def chat(message, history, temperature,do_sample, max_tokens):
     start_time = time.time()
     chat = []
     for item in history:
@@ -52,9 +56,13 @@ def chat(message, history, temperature,do_sample, top_k, max_tokens):
         streamer=streamer,
         max_new_tokens=max_tokens,
         do_sample=True,
-        top_k=top_k,
         temperature=temperature,
+        eos_token_id=terminators,
     )
+
+    if temperature == 0:
+        generate_kwargs['do_sample'] = False
+
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
 
@@ -86,14 +94,11 @@ demo = gr.ChatInterface(
         minimum=0, maximum=1, step=0.1, value=0.9, label="Temperature", render=False
     ),
     gr.Checkbox(label="Sampling",value=True),
-    gr.Slider(
-        minimum=1, maximum=10000, step=5, value=1000, label="top_k", render=False
-    ),
     gr.Slider(
         minimum=128,
         maximum=4096,
         step=1,
-        value=1024,
+        value=512,
         label="Max new tokens",
         render=False,
     ),
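
Read outside diff form, the commit amounts to two decoding tweaks plus some UI trimming: generation now stops on Llama 3's "<|eot_id|>" end-of-turn token as well as the regular EOS id, a temperature of 0 switches generation to greedy decoding instead of sampling, the top_k argument and its slider are dropped, and the default for "Max new tokens" falls from 1024 to 512. The sketch below is a minimal, self-contained rendering of that logic, not the Space's exact code; the helper name build_generate_kwargs and its parameters are illustrative, and in app.py the two stop ids come from the tokenizer loaded above.

# Minimal sketch of the decoding setup after this commit (illustrative, not the
# Space's exact code). The stop ids are passed in rather than read from a real
# Meta-Llama-3 tokenizer so the example runs standalone.
def build_generate_kwargs(streamer, max_tokens, temperature, eos_token_id, eot_token_id):
    # Stop on either the tokenizer's EOS token or Llama 3's "<|eot_id|>"
    # end-of-turn token, which the Instruct checkpoint emits at the end of a reply.
    terminators = [eos_token_id, eot_token_id]

    generate_kwargs = dict(
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        eos_token_id=terminators,
    )

    # Sampling with temperature 0 is not meaningful (recent transformers
    # versions reject it outright), so treat 0 as a request for greedy decoding.
    if temperature == 0:
        generate_kwargs["do_sample"] = False

    return generate_kwargs


# Quick check of the greedy fallback; the token ids here are placeholders.
kwargs = build_generate_kwargs(streamer=None, max_tokens=512, temperature=0,
                               eos_token_id=0, eot_token_id=1)
assert kwargs["do_sample"] is False and kwargs["eos_token_id"] == [0, 1]

transformers' model.generate accepts a list of ids for eos_token_id, which is what lets both stop tokens be active at once; without the "<|eot_id|>" entry, Llama 3 Instruct tends to keep generating past the end of its answer.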