tkdehf2 commited on
Commit
7662268
โ€ข
1 Parent(s): 4dd2646

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -5
app.py CHANGED
@@ -1,12 +1,19 @@
1
  import gradio as gr
2
- from transformers import pipeline
 
 
 
 
3
 
4
  def generate_diary(keywords):
5
- # InstructGPT ๋ชจ๋ธ ํ™œ์šฉ
6
- generator = pipeline('text-generation', model='anthropic/InstructGPT')
7
  prompt = f"์˜ค๋Š˜์˜ ์ผ๊ธฐ:\n\n{', '.join(keywords.split(','))}์— ๋Œ€ํ•œ ์ผ๊ธฐ๋ฅผ ์จ๋ด…์‹œ๋‹ค."
8
- output = generator(prompt, max_length=500, num_return_sequences=1, do_sample=True, top_k=50, top_p=0.95, num_beams=5, no_repeat_ngram_size=2)[0]['generated_text']
9
- return output
 
 
 
 
10
 
11
  def app():
12
  with gr.Blocks() as demo:
 
1
  import gradio as gr
2
+ from transformers import GPT2LMHeadModel, PreTrainedTokenizerFast
3
+
4
+ # KoGPT2 ๋ชจ๋ธ๊ณผ ํ† ํฌ๋‚˜์ด์ €๋ฅผ ์ „์—ญ ๋ณ€์ˆ˜๋กœ ๋ฏธ๋ฆฌ ๋กœ๋“œ
5
+ model = GPT2LMHeadModel.from_pretrained("skt/kogpt2-base-v2")
6
+ tokenizer = PreTrainedTokenizerFast.from_pretrained("skt/kogpt2-base-v2")
7
 
8
  def generate_diary(keywords):
9
+ # ํ‚ค์›Œ๋“œ ๊ธฐ๋ฐ˜ ์ผ๊ธฐ ์ƒ์„ฑ
 
10
  prompt = f"์˜ค๋Š˜์˜ ์ผ๊ธฐ:\n\n{', '.join(keywords.split(','))}์— ๋Œ€ํ•œ ์ผ๊ธฐ๋ฅผ ์จ๋ด…์‹œ๋‹ค."
11
+ input_ids = tokenizer.encode(prompt, return_tensors="pt")
12
+ output = model.generate(input_ids, max_length=500, num_return_sequences=1, do_sample=True, top_k=50, top_p=0.95, num_beams=5, no_repeat_ngram_size=2)
13
+
14
+ # ์ƒ์„ฑ๋œ ์ผ๊ธฐ ํ…์ŠคํŠธ ๋ฐ˜ํ™˜
15
+ diary = tokenizer.decode(output[0], skip_special_tokens=True)
16
+ return diary
17
 
18
  def app():
19
  with gr.Blocks() as demo: