Azure99 committed on
Commit
6bdb935
1 Parent(s): 6ae32bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -62,11 +62,13 @@ def generate(
62
  progress=gr.Progress(track_tqdm=True),
63
  ):
64
 
65
- input_ids = get_input_ids(LLM_PROMPT.replace("$USER_PROMPT", json.dumps(prompt)), BOT_PREFIX)
66
  generation_kwargs = dict(input_ids=torch.tensor([input_ids]).to(llm.device), do_sample=True,
67
  max_new_tokens=512, temperature=0.5, top_p=0.85, top_k=50, repetition_penalty=1.05)
68
  llm_result = llm.generate(**generation_kwargs)
69
- print(tokenizer.decode(llm_result.cpu()[0], skip_special_tokens=True))
 
 
70
 
71
  seed = random.randint(0, 2147483647)
72
  diffusion_pipe.to(device)
 
62
  progress=gr.Progress(track_tqdm=True),
63
  ):
64
 
65
+ input_ids = get_input_ids(LLM_PROMPT.replace("$USER_PROMPT", json.dumps(prompt, ensure_ascii=False)), BOT_PREFIX)
66
  generation_kwargs = dict(input_ids=torch.tensor([input_ids]).to(llm.device), do_sample=True,
67
  max_new_tokens=512, temperature=0.5, top_p=0.85, top_k=50, repetition_penalty=1.05)
68
  llm_result = llm.generate(**generation_kwargs)
69
+ llm_result = BOT_PREFIX + tokenizer.decode(llm_result.cpu()[0], skip_special_tokens=True)
70
+ print(llm_result)
71
+ print(json.loads(llm_result))
72
 
73
  seed = random.randint(0, 2147483647)
74
  diffusion_pipe.to(device)