Vipitis committed
Commit 012c551
1 Parent(s): 791c9fd

refactor generation utils

Files changed (3)
  1. app.py +9 -57
  2. utils/__init__.py +5 -3
  3. utils/generation.py +58 -0
app.py CHANGED
@@ -1,12 +1,12 @@
 import gradio as gr
-from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
 import datasets
 import numpy as np
 import torch
-from threading import Thread
 
 from utils.tree_utils import parse_functions, get_docstrings, grab_before_comments, line_chr2char, node_str_idx, replace_function
 from utils.html_utils import make_iframe, construct_embed
+from utils.generation import combine_generation_kwargs, stream_generation, construct_model_context
 PIPE = None
 
 intro_text = """
@@ -99,35 +99,6 @@ def _make_pipeline(model_cp = "Vipitis/santacoder-finetuned-Shadertoys-fine"): #
     print(f"loaded model {model_cp} as a pipline")
     return pipe
 
-def _run_generation(model_ctx:str, pipe, gen_kwargs:dict):
-    """
-    Text generation function
-    Args:
-        model_ctx (str): The context to start generation from.
-        pipe (Pipeline): The pipeline to use for generation.
-        gen_kwargs (dict): The generation kwargs.
-    Returns:
-        str: The generated text. (it iterates over time)
-    """
-    # Tokenize the model_context
-    model_inputs = pipe.tokenizer(model_ctx, return_tensors="pt")
-
-    # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
-    # in the main thread. Adds timeout to the streamer to handle exceptions in the generation thread.
-    streamer = TextIteratorStreamer(pipe.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15.0)
-    generate_kwargs = dict(model_inputs, streamer=streamer, **gen_kwargs)
-    t = Thread(target=pipe.model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    # Pull the generated text from the streamer, and update the model output.
-    model_output = ""
-    for new_text in streamer:
-        # print("step", end="")
-        model_output += new_text
-        yield model_output
-    streamer.on_finalized_text("stream reached the end.")
-    return model_output #is this ever reached?
-
 def process_retn(retn):
     return retn.split(";")[0].strip()
 
@@ -167,7 +138,7 @@ def alter_return(orig_code, func_idx, temperature, max_new_tokens, top_p, repeti
     else:
         raise gr.Error(f"func_idx must be int or str, not {type(func_idx)}")
 
-    generation_kwargs = _combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty)
+    generation_kwargs = combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty)
 
     retrns = []
     retrn_start_idx = orig_code.find("return")
@@ -189,14 +160,6 @@
     return altered_code
 
 
-def _combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty):
-    gen_kwargs = {}
-    gen_kwargs["temperature"] = temperature
-    gen_kwargs["max_new_tokens"] = max_new_tokens
-    gen_kwargs["top_p"] = top_p
-    gen_kwargs["repetition_penalty"] = repetition_penalty
-    return gen_kwargs
-
 def alter_body(old_code, func_id, funcs_list: list, prompt="", temperature=0.2, max_new_tokens=512, top_p=.95, repetition_penalty=1.2, pipeline=PIPE):
     """
     Replaces the body of a function with a generated one.
@@ -223,27 +186,16 @@ def alter_body(old_code, func_id, funcs_list: list, prompt="", temperature=0.2,
     func_node = funcs_list[func_id]
     print(f"using for generation: {func_node=}")
 
-    generation_kwargs = _combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty)
+    generation_kwargs = combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty)
+    model_context = construct_model_context(func_node, prompt=prompt)
+    print(f"{model_context=}")
 
-    func_start_idx = line_chr2char(old_code, func_node.start_point[0], func_node.start_point[1])
-    identifier_str = func_node.child_by_field_name("type").text.decode() + " " + func_node.child_by_field_name("declarator").text.decode() #func_start_idx:body_start_idx?
     body_node = func_node.child_by_field_name("body")
     body_start_idx, body_end_idx = node_str_idx(body_node)
-    model_context = identifier_str # base case
-
-    docstring = get_docstrings(func_node) #might be empty?
-    if docstring:
-        model_context = model_context + "\n" + docstring
-    model_context = grab_before_comments(func_node) + model_context #prepend comments
-    if prompt != "":
-        model_context = f"//avialable functions: {','.join([n.child_by_field_name('declarator').text.decode() for n in funcs_list])}\n" + model_context #prepend available functions
-        model_context = "//Title: " + prompt + "\n" + model_context #prepend user prompt/title
-    model_context = "//Language: Shadertoy GLSL fragment shader\n" + model_context #prepend system prompt, language hint
-    print(f"{model_context=}")
     # generation = pipeline(model_context, return_full_text=False, **generation_kwargs)[0]["generated_text"]
-    generation = _run_generation(model_context, pipeline, generation_kwargs)
+    generation = stream_generation(model_context, pipeline, generation_kwargs)
     for i in generation:
-        print(f"{i=}")
+        # print(f"{i=}")
         yield model_context + i #fix in between, do all the stuff in the end?
         generation = i[:] #seems to work
     print(f"{generation=}")
@@ -253,7 +205,7 @@
         first_gened_func = parse_functions(ctx_with_generation)[0] # truncate generation to a single function?
     except IndexError:
         print("generation wasn't a full function.")
-        altered_code = old_code[:func_start_idx] + model_context + generation + "//the generation didn't complete the function!\n" + old_code[body_end_idx:] #needs a newline to break out of the comment.
+        altered_code = old_code[:body_start_idx] + generation + "//the generation didn't complete the function!\n" + old_code[body_end_idx:] #needs a newline to break out of the comment.
         return altered_code
     altered_code = replace_function(func_node, first_gened_func)
     yield altered_code #yield once so it updates? -> works... gg but doesn't seem to do it for the dropdown
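
Since alter_body remains a generator that yields the growing result, it can also be consumed outside the Gradio UI; a minimal sketch (the shader snippet and arguments here are illustrative, not part of this commit):

    # Illustrative only: drive alter_body without the Gradio frontend.
    pipe = _make_pipeline()  # default checkpoint from app.py
    src = "float rand(vec2 uv){return fract(sin(dot(uv, vec2(12.9898, 78.233))) * 43758.5453);}"
    funcs = parse_functions(src)
    for partial in alter_body(src, func_id=0, funcs_list=funcs, prompt="noise", pipeline=pipe):
        print(partial)  # model context plus whatever has streamed so far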
utils/__init__.py CHANGED
@@ -1,7 +1,9 @@
-from .tree_utils import (parse_functions, get_docstrings, grab_before_comments, line_chr2char)
+from .tree_utils import (parse_functions, get_docstrings, grab_before_comments, line_chr2char, replace_function, get_root, node_str_idx, give_tree)
 from .html_utils import (make_iframe, make_script, construct_embed)
+from .generation import (combine_generation_kwargs, stream_generation, construct_model_context)
 
-tree_funcs = ["parse_functions", "get_docstrings", "grab_before_comments", "line_chr2char"]
+tree_funcs = ["parse_functions", "get_docstrings", "grab_before_comments", "line_chr2char", "replace_function", "get_root", "node_str_idx", "give_tree"]
 html_funcs = ["make_iframe", "make_script", "construct_embed"]
+gen_funcs = ["combine_generation_kwargs", "stream_generation", "construct_model_context"]
 
-__all__ = tree_funcs + html_funcs
+__all__ = tree_funcs + html_funcs + gen_funcs
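
With the expanded re-exports, downstream code can import every helper from the package root, for example:

    from utils import parse_functions, replace_function, stream_generation, combine_generation_kwargs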
utils/generation.py ADDED
@@ -0,0 +1,58 @@
+from transformers import TextIteratorStreamer
+from threading import Thread
+from utils.tree_utils import get_docstrings, grab_before_comments
+
+def combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty):
+    """
+    Combines the generation kwargs into a single dict.
+    """
+    gen_kwargs = {}
+    gen_kwargs["temperature"] = temperature
+    gen_kwargs["max_new_tokens"] = max_new_tokens
+    gen_kwargs["top_p"] = top_p
+    gen_kwargs["repetition_penalty"] = repetition_penalty
+    return gen_kwargs
+
+
+def stream_generation(prompt:str, pipe, gen_kwargs:dict):
+    """
+    Text generation function
+    Args:
+        prompt (str): The context to start generation from.
+        pipe (Pipeline): The pipeline to use for generation.
+        gen_kwargs (dict): The generation kwargs.
+    Returns:
+        str: The generated text. (it iterates over time)
+    """
+    # Tokenize the model_context
+    model_inputs = pipe.tokenizer(prompt, return_tensors="pt")
+
+    # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
+    # in the main thread. Adds timeout to the streamer to handle exceptions in the generation thread.
+    streamer = TextIteratorStreamer(pipe.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15.0)
+    generate_kwargs = dict(model_inputs, streamer=streamer, **gen_kwargs)
+    t = Thread(target=pipe.model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    # Pull the generated text from the streamer, and update the model output.
+    model_output = ""
+    for new_text in streamer:
+        # print("step", end="")
+        model_output += new_text
+        yield model_output
+    streamer.on_finalized_text("stream reached the end.")
+    return model_output #is this ever reached?
+
+def construct_model_context(func_node, prompt="") -> str:
+    """
+    Constructs the model context from a function node.
+    """
+    model_context = func_node.child_by_field_name("type").text.decode() + " " + func_node.child_by_field_name("declarator").text.decode() #func_start_idx:body_start_idx?
+    docstring = get_docstrings(func_node) #might be empty?
+    if docstring:
+        model_context = model_context + "\n" + docstring
+    model_context = grab_before_comments(func_node) + model_context #prepend comments
+    if prompt != "":
+        model_context = "//Title: " + prompt + "\n" + model_context #prepend user prompt/title
+    model_context = "//Language: Shadertoy GLSL fragment shader\n" + model_context #prepend system prompt, language hint
+    return model_context
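
Apart from the two tree_utils comment helpers, the new module stands alone, so it can stream from any transformers text-generation pipeline; a minimal usage sketch (sampling values are the defaults from alter_body, the one-line GLSL prompt is illustrative):

    from transformers import pipeline
    from utils.generation import combine_generation_kwargs, stream_generation

    pipe = pipeline("text-generation", model="Vipitis/santacoder-finetuned-Shadertoys-fine")
    gen_kwargs = combine_generation_kwargs(temperature=0.2, max_new_tokens=512, top_p=0.95, repetition_penalty=1.2)
    # generate() runs on a worker thread; the main thread drains the
    # TextIteratorStreamer, so each iteration sees the output so far.
    for partial in stream_generation("float rand(vec2 uv)", pipe, gen_kwargs):
        print(partial)

Running generate on a background thread while the caller drains TextIteratorStreamer is the standard transformers streaming pattern; the 15-second streamer timeout keeps an exception in the generation thread from hanging the consumer loop.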