tkdehf2 committed on
Commit
8e5abcb
•
1 Parent(s): e59909f

Update app.py

Files changed (1)
  1. app.py +3 -2
app.py CHANGED
@@ -7,8 +7,9 @@ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
 
 def generate_diary(keywords):
     # Keyword-based fine-tuning
-    input_ids = tokenizer.encode(" ".join(keywords), return_tensors="pt")
-    output = model.generate(input_ids, max_length=500, num_return_sequences=1, do_sample=True, top_k=50, top_p=0.95, num_beams=5)
+    prompt = f"오늘의 일기:\n\n{' '.join(keywords.split(','))}에 관한 일기를 써봅시다."
+    input_ids = tokenizer.encode(prompt, return_tensors="pt")
+    output = model.generate(input_ids, max_length=500, num_return_sequences=1, do_sample=True, top_k=50, top_p=0.95, num_beams=5, no_repeat_ngram_size=2)
 
     # Return the generated diary text
     diary = tokenizer.decode(output[0], skip_special_tokens=True)
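
For context, here is a minimal sketch of how the updated generate_diary could run end to end. The hunk only shows the tokenizer being loaded, so the GPT2LMHeadModel line, the function's return statement, and the pad_token_id argument (which silences GPT-2's missing-pad-token warning) are assumptions, not part of the commit. The Korean prompt roughly means "Today's diary: Let's write a diary about <keywords>."

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")  # assumption: the matching LM head model, not shown in the hunk

def generate_diary(keywords):
    # Build the keyword-based prompt exactly as in the commit;
    # keywords is expected to be a comma-separated string.
    prompt = f"오늘의 일기:\n\n{' '.join(keywords.split(','))}에 관한 일기를 써봅시다."
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    output = model.generate(
        input_ids,
        max_length=500,
        num_return_sequences=1,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        num_beams=5,
        no_repeat_ngram_size=2,               # added in this commit
        pad_token_id=tokenizer.eos_token_id,  # assumption: GPT-2 defines no pad token
    )
    # Return the generated diary text
    return tokenizer.decode(output[0], skip_special_tokens=True)

print(generate_diary("바다,여름"))

The practical effect of no_repeat_ngram_size=2, the parameter this commit adds, is that no two-token sequence can appear twice in the output, a common guard against the repetitive loops GPT-2 tends to fall into at a long max_length such as 500.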