Vipitis committed
Commit 3c0111d
1 Parent(s): 6631a55

streaming text generation in working shape

Files changed (1):
  1. app.py +42 -5
app.py CHANGED
@@ -1,9 +1,10 @@
 import gradio as gr
-from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 import datasets
 import asyncio
 import numpy as np
 import torch
+from threading import Thread
 
 def make_script(shader_code):
     # code copied and fixed(escaping single quotes to double quotes!!!) from https://webglfundamentals.org/webgl/webgl-shadertoy.html
@@ -274,6 +275,7 @@ outro_text ="""
 - [~] include some context for prompt (title, comments before a functions) - now works with the first comment inside a function body (has to be first)
 - [] gradio examples
 - [] use GPU if available, respect memory restrictions.
+- [~] stream model generation (maybe in a new window?) - WIP for body gen right now -> janky solution works.
 
 ### Notes:
 - this is meant as a resource to show code generation for a "creative" task.
@@ -342,6 +344,34 @@ def _make_pipeline(model_cp = "Vipitis/santacoder-finetuned-Shadertoys-fine"): #
     print(f"loaded model {model_cp} as a pipline")
     return pipe
 
+def _run_generation(model_ctx:str, pipe, gen_kwargs:dict):
+    """
+    Text generation function
+    Args:
+        model_ctx (str): The context to start generation from.
+        pipe (Pipeline): The pipeline to use for generation.
+        gen_kwargs (dict): The generation kwargs.
+    Returns:
+        str: The generated text. (it iterates over time)
+    """
+    # Tokenize the model_context
+    model_inputs = pipe.tokenizer(model_ctx, return_tensors="pt")
+
+    # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
+    # in the main thread. Adds timeout to the streamer to handle exceptions in the generation thread.
+    streamer = TextIteratorStreamer(pipe.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15.0)
+    generate_kwargs = dict(model_inputs, streamer=streamer, **gen_kwargs)
+    t = Thread(target=pipe.model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    # Pull the generated text from the streamer, and update the model output.
+    model_output = ""
+    for new_text in streamer:
+        # print("step", end="")
+        model_output += new_text
+        yield model_output
+    streamer.on_finalized_text("stream reached the end.")
+    return model_output #is this ever reached?
 
 def process_retn(retn):
     return retn.split(";")[0].strip()
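Aside: the worker-thread-plus-TextIteratorStreamer pattern added above can be tried in isolation. A minimal sketch, assuming a generic placeholder checkpoint ("gpt2") and prompt rather than the Space's fine-tuned santacoder:

    # Minimal standalone sketch of threaded streaming generation.
    # Checkpoint and prompt are placeholders, not from this commit.
    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    model_inputs = tokenizer("void mainImage( out vec4 fragColor", return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15.0)

    # generate() blocks, so it runs on a worker thread; decoded text arrives via the streamer.
    thread = Thread(target=model.generate, kwargs=dict(model_inputs, streamer=streamer, max_new_tokens=32))
    thread.start()

    text = ""
    for chunk in streamer:  # blocks until the next chunk arrives or the 15 s timeout fires
        text += chunk
        print(chunk, end="", flush=True)
    thread.join()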
@@ -458,7 +488,12 @@ def alter_body(old_code, func_id, funcs_list: list, temperature, max_new_tokens,
     # print(second_child.text.decode())
     model_context += " { \n " + second_child.text.decode()
     print(f"{model_context=}")
-    generation = pipeline(model_context, return_full_text=False, **generation_kwargs)[0]["generated_text"]
+    # generation = pipeline(model_context, return_full_text=False, **generation_kwargs)[0]["generated_text"]
+    generation = _run_generation(model_context, pipeline, generation_kwargs)
+    for i in generation:
+        print(f"{i=}")
+        yield model_context + i, pipeline #fix in between, do all the stuff in the end?
+    generation = i[:] #seems to work
     print(f"{generation=}")
     ctx_with_generation = model_context + generation
     print(f"{ctx_with_generation=}")
@@ -474,7 +509,9 @@ def alter_body(old_code, func_id, funcs_list: list, temperature, max_new_tokens,
     generated_body = first_gened_func.child_by_field_name("body").text.decode()
     print(f"{generated_body=}")
     altered_code = old_code[:func_start_idx] + identifier_str + generated_body + old_code[body_end_idx:]
-    return altered_code, pipeline
+    print(f"{altered_code=}") #we get here successfully
+    yield altered_code, pipeline #yield once so it updates? -> works... gg
+    return altered_code, pipeline #never gets used by the code block? maybe I need to yield it first? but works in the ov_notebook
 
 def add_history(func_id, orig_rtn, gened_rtn, history):
     # is this a list? or a JSON dict?
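Aside: the trailing return above never reaches the frontend by design of Python generators; once a function body contains yield, its return value only travels inside StopIteration, which a plain for-loop swallows. A minimal illustration:

    # Why the return after the yield is effectively invisible to callers.
    def gen():
        yield "update"
        return "final"

    for item in gen():
        print(item)  # prints only "update"

    g = gen()
    next(g)  # "update"
    try:
        next(g)
    except StopIteration as stop:
        print(stop.value)  # "final" is only retrievable here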
@@ -524,7 +561,7 @@ with gr.Blocks() as site:
     with column_2:
         top_p = gr.Slider(
             label="Top-p (nucleus sampling)",
-            value=0.30,
+            value=0.85,
             minimum=0.0,
             maximum=1,
             step=0.05,
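Aside: the new default widens the sampling nucleus considerably: with top_p=0.85 the model samples from the smallest token set whose cumulative probability reaches 0.85, whereas 0.30 was close to greedy decoding. A sketch of the effect, with a placeholder checkpoint and prompt:

    # Illustrative only: checkpoint and prompt are placeholders.
    from transformers import pipeline

    pipe = pipeline("text-generation", model="gpt2")
    out = pipe("void mainImage(", do_sample=True, top_p=0.85, max_new_tokens=16)
    print(out[0]["generated_text"])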
@@ -563,4 +600,4 @@ with gr.Blocks() as site:
     gen_func_button.click(fn=alter_body, inputs=[sample_code, func_dropdown, funcs, temperature, max_new_tokens, top_p, repetition_penalty, pipe], outputs=[sample_code, pipe])
     sample_code.change(fn=list_dropdown, inputs=[sample_code], outputs=[funcs, func_dropdown]) # to update this after generation, so spans aren't messed up
     sample_code.change(fn=make_iframe, inputs=[sample_code], outputs=[our_embed]) #twice could cause issues, find better ways.
-site.launch()
+site.launch(enable_queue=True)
 