Josh Nguyen committed on
Commit
d38f5f1
1 Parent(s): 422252e

Fix a bug in generate_text

Browse files
Files changed (1) hide show
  1. app.py +3 -7
app.py CHANGED
@@ -34,12 +34,10 @@ def generate_text(prompt: str,
34
  temperature: float = 0.5,
35
  top_p: float = 0.95,
36
  top_k: int = 50) -> str:
37
-
38
  # Encode the prompt
39
  inputs = tokenizer([prompt],
40
  return_tensors='pt',
41
  add_special_tokens=False).to(DEVICE)
42
-
43
  # Prepare arguments for generation
44
  input_length = inputs["input_ids"].shape[-1]
45
  max_new_tokens = min(max_new_tokens, WINDOW_SIZE - input_length)
@@ -56,8 +54,8 @@ def generate_text(prompt: str,
56
  skip_prompt=True,
57
  skip_special_tokens=True)
58
  generation_kwargs = dict(
59
- inputs=inputs,
60
- streamer=inputs,
61
  max_new_tokens=max_new_tokens,
62
  do_sample=True,
63
  top_p=top_p,
@@ -65,12 +63,10 @@ def generate_text(prompt: str,
65
  temperature=temperature,
66
  num_beams=1,
67
  )
68
-
69
  # Generate text
70
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
71
  thread.start()
72
-
73
- generated_text = ""
74
  for new_text in streamer:
75
  generated_text += new_text
76
  return generated_text
 
34
  temperature: float = 0.5,
35
  top_p: float = 0.95,
36
  top_k: int = 50) -> str:
 
37
  # Encode the prompt
38
  inputs = tokenizer([prompt],
39
  return_tensors='pt',
40
  add_special_tokens=False).to(DEVICE)
 
41
  # Prepare arguments for generation
42
  input_length = inputs["input_ids"].shape[-1]
43
  max_new_tokens = min(max_new_tokens, WINDOW_SIZE - input_length)
 
54
  skip_prompt=True,
55
  skip_special_tokens=True)
56
  generation_kwargs = dict(
57
+ **inputs,
58
+ streamer=streamer,
59
  max_new_tokens=max_new_tokens,
60
  do_sample=True,
61
  top_p=top_p,
 
63
  temperature=temperature,
64
  num_beams=1,
65
  )
 
66
  # Generate text
67
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
68
  thread.start()
69
+ generated_text = prompt
 
70
  for new_text in streamer:
71
  generated_text += new_text
72
  return generated_text