mttrz committed on
Commit
d8b3406
1 Parent(s): d0c35a3

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +1 -10
main.py CHANGED
@@ -18,20 +18,14 @@ class Item(BaseModel):
18
  repetition_penalty: float = 1.0
19
 
20
  def format_prompt(message, history):
  # NOTE(review): this is the PRE-commit version shown on the deletion
  # ("-") side of the diff; the bare numeric lines are diff line numbers,
  # not code. It wrapped each (user_prompt, bot_response) turn of
  # `history`, then `message`, in Mistral-style [INST] ... [/INST]
  # instruction tags with <s>/</s> sequence markers.
21
- prompt = "<s>"
22
- for user_prompt, bot_response in history:
23
- prompt += f"[INST] {user_prompt} [/INST]"
24
- prompt += f" {bot_response}</s> "
25
- prompt += f"[INST] {message} [/INST]"
26
  return prompt
27
 
28
  def generate(item: Item):
29
- print('Sto partendo con la funzione')
30
  temperature = float(item.temperature)
31
  if temperature < 1e-2:
32
  temperature = 1e-2
33
  top_p = float(item.top_p)
34
- print('kwargs')
35
  generate_kwargs = dict(
36
  temperature=temperature,
37
  max_new_tokens=item.max_new_tokens,
@@ -40,15 +34,12 @@ def generate(item: Item):
40
  do_sample=True,
41
  seed=42,
42
  )
43
- print('Formatto il prompt')
44
  formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
45
  stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
46
- print('Parte lo stream')
47
  output = ""
48
 
49
  for response in stream:
50
  output += response.token.text
51
- print('HO QUASI FINITO')
52
  return output
53
 
54
  @app.post("/generate/")
 
18
  repetition_penalty: float = 1.0
19
 
20
def format_prompt(message, history):
    """Return the prompt text for generation.

    After this commit the prompt is simply *message*, unchanged.
    The *history* parameter is kept only so existing callers
    (``format_prompt(..., item.history)``) continue to work; it is
    intentionally unused.
    """
    return message
23
 
24
  def generate(item: Item):
  # Run one text-generation request described by `item` and return the
  # full generated text: the token stream from the client is consumed
  # and concatenated into a single string before returning.
  # NOTE(review): this is a diff-view dump — the bare numeric lines are
  # diff line numbers, not code, and new-file lines 32-33 (presumably
  # top_p= / repetition_penalty= entries of generate_kwargs) are hidden
  # by the hunk boundary; confirm against the full main.py.
 
25
  temperature = float(item.temperature)
26
  if temperature < 1e-2:
  # Clamp to a small positive floor — sampling with temperature <= 0 is invalid.
27
  temperature = 1e-2
28
  top_p = float(item.top_p)
 
29
  generate_kwargs = dict(
30
  temperature=temperature,
31
  max_new_tokens=item.max_new_tokens,
 
34
  do_sample=True,
  # Fixed seed: identical inputs reproduce identical samples — TODO confirm intended.
35
  seed=42,
36
  )
 
  # System prompt and user prompt are joined with ", " before formatting;
  # item.history is forwarded but ignored by the post-commit format_prompt.
37
  formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
38
  stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
 
39
  output = ""
40
 
  # Accumulate streamed tokens into one string (stream=True yields per-token responses).
41
  for response in stream:
42
  output += response.token.text
 
43
  return output
44
 
45
  @app.post("/generate/")