hunkim commited on
Commit
eb7e846
ยท
1 Parent(s): da4c6e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -7
app.py CHANGED
@@ -4,33 +4,31 @@ import streamlit as st
4
  import torch
5
  from transformers import AutoTokenizer, AutoModelForCausalLM
6
 
7
- '''
8
  tokenizer = AutoTokenizer.from_pretrained(
9
- 'kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b',
10
  bos_token='[BOS]', eos_token='[EOS]', unk_token='[UNK]', pad_token='[PAD]', mask_token='[MASK]'
11
  )
12
 
13
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
 
15
  model = AutoModelForCausalLM.from_pretrained(
16
- 'kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b',
17
  pad_token_id=tokenizer.eos_token_id,
18
  torch_dtype=torch.float16, low_cpu_mem_usage=False
19
  ).to(device=device, non_blocking=True)
20
  _ = model.eval()
21
- '''
22
  print("Model loading done!")
23
 
24
  def gpt(prompt):
25
- return prompt
26
- '''
27
  with torch.no_grad():
28
  tokens = tokenizer.encode(prompt, return_tensors='pt').to(device=device, non_blocking=True)
29
  gen_tokens = model.generate(tokens, do_sample=True, temperature=0.8, max_length=256)
30
  generated = tokenizer.batch_decode(gen_tokens)[0]
31
 
32
  return generated
33
- '''
34
 
35
  #prompts
36
  st.title("์—ฌ๋Ÿฌ๋ถ„๋“ค์˜ ๋ฌธ์žฅ์„ ์™„์„ฑํ•ด์ค๋‹ˆ๋‹ค. ๐Ÿค–")
 
4
  import torch
5
  from transformers import AutoTokenizer, AutoModelForCausalLM
6
 
7
+
8
  tokenizer = AutoTokenizer.from_pretrained(
9
+ 'kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b', cache_dir='./model_dir/',
10
  bos_token='[BOS]', eos_token='[EOS]', unk_token='[UNK]', pad_token='[PAD]', mask_token='[MASK]'
11
  )
12
 
13
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
 
15
  model = AutoModelForCausalLM.from_pretrained(
16
+ 'kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b',cache_dir='./model_dir/',
17
  pad_token_id=tokenizer.eos_token_id,
18
  torch_dtype=torch.float16, low_cpu_mem_usage=False
19
  ).to(device=device, non_blocking=True)
20
  _ = model.eval()
21
+
22
  print("Model loading done!")
23
 
24
  def gpt(prompt):
 
 
25
  with torch.no_grad():
26
  tokens = tokenizer.encode(prompt, return_tensors='pt').to(device=device, non_blocking=True)
27
  gen_tokens = model.generate(tokens, do_sample=True, temperature=0.8, max_length=256)
28
  generated = tokenizer.batch_decode(gen_tokens)[0]
29
 
30
  return generated
31
+
32
 
33
  #prompts
34
  st.title("์—ฌ๋Ÿฌ๋ถ„๋“ค์˜ ๋ฌธ์žฅ์„ ์™„์„ฑํ•ด์ค๋‹ˆ๋‹ค. ๐Ÿค–")