afrizalha commited on
Commit
579abe4
1 Parent(s): c6dc7c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -25
app.py CHANGED
@@ -17,39 +17,24 @@ template = """<|im_start|>system
17
  def generate(query, temp, top_p):
18
  inputs = template.format(prompt=query)
19
  inputs = tokenizer([inputs], return_tensors="pt").to(model.device)
20
-
21
- streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
22
-
23
- generation_kwargs = dict(
24
  inputs=inputs.input_ids,
25
  max_new_tokens=1024,
26
  do_sample=True,
27
  temperature=temp,
28
- top_p=top_p,
29
- streamer=streamer,
30
- )
31
-
32
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
33
- thread.start()
34
-
35
- return streamer
36
-
37
  with gr.Blocks(theme=gr.themes.Soft()) as app:
38
  input = gr.Textbox(label="Prompt", value="Pripun kulo saged nyinaoni Basa Jawa kanthi sae?")
39
  output = gr.Textbox(label="Response", scale=2)
40
  temp = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.1, value=0.5)
41
  top_p = gr.Slider(label="Top P", minimum=0.1, maximum=1.0, step=0.1, value=0.5)
42
-
43
- def stream(query, temp, top_p):
44
- for token in generate(query, temp, top_p):
45
- yield token
46
-
47
  gr.Interface(
48
- fn=stream,
49
- inputs=[input, temp, top_p],
50
- outputs=[output],
51
- allow_flagging="never",
52
- title="Bakpia-V1 (Streaming)",
53
- )
54
-
55
  app.launch()
 
17
  def generate(query, temp, top_p):
18
  inputs = template.format(prompt=query)
19
  inputs = tokenizer([inputs], return_tensors="pt").to(model.device)
20
+ outputs = model.generate(
 
 
 
21
  inputs=inputs.input_ids,
22
  max_new_tokens=1024,
23
  do_sample=True,
24
  temperature=temp,
25
+ top_p=top_p)
26
+ outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
27
+ return outputs
 
 
 
 
 
 
28
  with gr.Blocks(theme=gr.themes.Soft()) as app:
29
  input = gr.Textbox(label="Prompt", value="Pripun kulo saged nyinaoni Basa Jawa kanthi sae?")
30
  output = gr.Textbox(label="Response", scale=2)
31
  temp = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.1, value=0.5)
32
  top_p = gr.Slider(label="Top P", minimum=0.1, maximum=1.0, step=0.1, value=0.5)
 
 
 
 
 
33
  gr.Interface(
34
+ fn=generate,
35
+ inputs=[input,temp,top_p],
36
+ outputs=[output],
37
+ allow_flagging="never",
38
+ title="Bakpia-V1",
39
+ )
 
40
  app.launch()