hunkim committed on
Commit
94a6c27
•
1 Parent(s): ff78a89

Create app.py

Files changed (1)
  1. app.py +54 -0
app.py ADDED
@@ -0,0 +1,54 @@
+ # -*- coding: utf-8 -*-
+ import streamlit as st
+ # model-loading code from https://huggingface.co/kakaobrain/kogpt
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ tokenizer = AutoTokenizer.from_pretrained(
+     'kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b',
+     bos_token='[BOS]', eos_token='[EOS]', unk_token='[UNK]', pad_token='[PAD]', mask_token='[MASK]'
+ )
+ model = AutoModelForCausalLM.from_pretrained(
+     'kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b',
+     pad_token_id=tokenizer.eos_token_id,
+     torch_dtype=torch.float16, low_cpu_mem_usage=False  # note: some torch CPU ops may not support float16; float32 may be needed
+ ).to(device='cpu', non_blocking=True)
+ _ = model.eval()
+
+ print("Model loading done!")
+
+ def gpt(prompt):
+     with torch.no_grad():  # inference only, no gradients
+         tokens = tokenizer.encode(prompt, return_tensors='pt').to(device='cpu', non_blocking=True)
+         gen_tokens = model.generate(tokens, do_sample=True, temperature=0.8, max_length=256)
+         generated = tokenizer.batch_decode(gen_tokens)[0]
+
+     return generated
+
+ # prompts
+ st.title("여러분들의 문장을 완성해줍니다. 🤖")  # "Completes your sentences. 🤖"
+ st.markdown("카카오 gpt 사용합니다.")  # "Uses Kakao GPT."
+ st.subheader("몇가지 예제: ")  # "A few examples: "
+ example_1_str = "오늘의 날씨는 너무 눈부시다. 내일은 "  # "Today's weather is so dazzling. Tomorrow "
+ example_2_str = "우리는 행복을 언제나 갈망하지만 항상 "  # "We always long for happiness, but always "
+ example_1 = st.button(example_1_str)
+ example_2 = st.button(example_2_str)
+ textbox = st.text_area('오늘은 아름다움을 향해 달리고 ', '', height=100, max_chars=500)  # "Today, running toward beauty "
+ button = st.button('생성:')  # "Generate:"
+ # output
+ st.subheader("결과값: ")  # "Result: "
+ if example_1:
+     with st.spinner('In progress.......'):
+         output_text = gpt(example_1_str)
+     st.markdown("## " + output_text)
+ if example_2:
+     with st.spinner('In progress.......'):
+         output_text = gpt(example_2_str)
+     st.markdown("## " + output_text)
+ if button:
+     with st.spinner('In progress.......'):
+         if textbox:
+             output_text = gpt(textbox)
+         else:
+             output_text = " "
+     st.markdown("## " + output_text)
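
A note on structure (not part of this commit): Streamlit re-runs the whole script on every widget interaction, so as written the 6B-parameter checkpoint is reloaded whenever the page state changes. Below is a minimal sketch of how the load could be cached, assuming the Streamlit version in use still provides st.cache (newer releases expose st.cache_resource instead); the load_model helper name is illustrative only.

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

@st.cache(allow_output_mutation=True)  # or @st.cache_resource on newer Streamlit releases
def load_model():
    # Hypothetical helper mirroring the loading block in app.py above.
    tokenizer = AutoTokenizer.from_pretrained(
        'kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b',
        bos_token='[BOS]', eos_token='[EOS]', unk_token='[UNK]',
        pad_token='[PAD]', mask_token='[MASK]'
    )
    model = AutoModelForCausalLM.from_pretrained(
        'kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b',
        pad_token_id=tokenizer.eos_token_id,
        torch_dtype=torch.float16, low_cpu_mem_usage=False
    ).to(device='cpu', non_blocking=True)
    model.eval()
    return tokenizer, model

tokenizer, model = load_model()  # cached, so the weights are reused across reruns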