ldhldh committed
Commit 082349c
1 Parent(s): e65ea05

Update app.py

Files changed (1):
  app.py +10 -60
app.py CHANGED
@@ -1,5 +1,5 @@
 from threading import Thread
-
+from llama_cpp import Llama
 import torch
 import gradio as gr
 import re
@@ -8,56 +8,18 @@ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 print("Running on device:", torch_device)
 print("CPU threads:", torch.get_num_threads())
 
-peft_model_id = "ldhldh/1.3_40kstep"
-
-#peft_model_id = "ldhldh/polyglot-ko-1.3b_lora_big_tern_30kstep"
-# 20k or 30k
-#18k > has an issue where it also generates the other speaker's lines
-#8k > a bit underwhelming?
-
-base_model = AutoModelForCausalLM.from_pretrained("Skyranch/KoAlpaca-Polyglot-12.8B-ggml-model-f16", hf=True, model_type='gpt_neox')
-tokenizer = AutoTokenizer.from_pretrained("beomi/KoAlpaca-Polyglot-12.8B")
-
-#base_model = AutoModelForCausalLM.from_pretrained("EleutherAI/polyglot-ko-3.8b")
-#tokenizer = AutoTokenizer.from_pretrained("EleutherAI/polyglot-ko-3.8b")
-base_model.eval()
-base_model.config.use_cache = True
-
-
-#model = PeftModel.from_pretrained(base_model, peft_model_id, device_map="auto")
-#model.eval()
-#model.config.use_cache = True
-
-def gen(x, top_p, top_k, temperature, max_new_tokens, repetition_penalty):
-    gened = base_model.generate(
-        **tokenizer(
-            f"{x}",
-            return_tensors='pt',
-            return_token_type_ids=False
-        ),
-        #bad_words_ids = bad_words_ids ,
-        max_new_tokens=max_new_tokens,
-        min_new_tokens = 5,
-        exponential_decay_length_penalty = (max_new_tokens/2, 1.1),
-        top_p=top_p,
-        top_k=top_k,
-        temperature = temperature,
-        early_stopping=True,
-        do_sample=True,
-        eos_token_id=2,
-        pad_token_id=2,
-        #stopping_criteria = stopping_criteria,
-        repetition_penalty=repetition_penalty,
-        no_repeat_ngram_size = 2
-    )
-
-    model_output = tokenizer.decode(gened[0])
-    return model_output
+llm = Llama(model_path = 'Llama-2-ko-7B-chat-gguf-q4_0.bin',
+            n_ctx=1024,
+            )
+
+def gen(x, max_new_tokens):
+    output = llm(f"Q: {x} A: ", max_tokens=1024, stop=["Q:", "\n"], echo=True)
+
+    return output['choices'][0]['text'].replace('▁',' ')
 
 def reset_textbox():
     return gr.update(value='')
 
-
 with gr.Blocks() as demo:
     duplicate_link = "https://huggingface.co/spaces/beomi/KoRWKV-1.5B?duplicate=true"
     gr.Markdown(
@@ -67,7 +29,7 @@ with gr.Blocks() as demo:
     with gr.Row():
         with gr.Column(scale=4):
             user_text = gr.Textbox(
-                placeholder='\nfriend: 우리 여행 갈래? \nyou:',
+                placeholder='우리 여행 갈래?',
                 label="User input"
             )
             model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
@@ -76,18 +38,6 @@
             max_new_tokens = gr.Slider(
                 minimum=1, maximum=200, value=20, step=1, interactive=True, label="Max New Tokens",
             )
-            top_p = gr.Slider(
-                minimum=0.05, maximum=1.0, value=0.8, step=0.05, interactive=True, label="Top-p (nucleus sampling)",
-            )
-            top_k = gr.Slider(
-                minimum=5, maximum=100, value=30, step=5, interactive=True, label="Top-k (nucleus sampling)",
-            )
-            temperature = gr.Slider(
-                minimum=0.1, maximum=2.0, value=0.5, step=0.1, interactive=True, label="Temperature",
-            )
-            repetition_penalty = gr.Slider(
-                minimum=1.0, maximum=3.0, value=1.2, step=0.1, interactive=True, label="repetition_penalty",
-            )
 
-    button_submit.click(gen, [user_text, top_p, top_k, temperature, max_new_tokens, repetition_penalty], model_output)
+    button_submit.click(gen, [user_text, max_new_tokens], model_output)
     demo.queue(max_size=32).launch(enable_queue=True)
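One caveat on the new generation path: the rewritten gen() accepts max_new_tokens from the Gradio slider but passes a hardcoded max_tokens=1024 to the model, so the "Max New Tokens" control has no effect. Below is a minimal sketch of how the slider value could be forwarded instead, assuming llama-cpp-python's completion-style Llama call as used in the commit; it is not the committed code.

    from llama_cpp import Llama

    # Same GGUF checkpoint and context size as the commit.
    llm = Llama(model_path='Llama-2-ko-7B-chat-gguf-q4_0.bin', n_ctx=1024)

    def gen(x, max_new_tokens):
        # Same Q/A prompt and stop strings as the committed code, but the
        # slider value is forwarded instead of a fixed max_tokens=1024.
        output = llm(
            f"Q: {x} A: ",
            max_tokens=int(max_new_tokens),  # Gradio sliders may return floats
            stop=["Q:", "\n"],
            echo=True,  # as in the commit: returned text includes the prompt
        )
        # Strip SentencePiece underline markers, as the committed code does.
        return output['choices'][0]['text'].replace('▁', ' ')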