ldhldh commited on
Commit
9ca25de
0 Parent(s):

Duplicate from ldhldh/polyglot_ko_1.3B_PEFT_demo

Browse files
Files changed (4) hide show
  1. .gitattributes +34 -0
  2. README.md +14 -0
  3. app.py +123 -0
  4. requirements.txt +7 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: 🤗 KoRWKV-1.5B 🔥Streaming🔥
3
+ emoji: 💻
4
+ colorFrom: yellow
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 3.29.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ duplicated_from: ldhldh/polyglot_ko_1.3B_PEFT_demo
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from threading import Thread
2
+
3
+ import torch
4
+ import gradio as gr
5
+ from transformers import pipeline,AutoTokenizer, AutoModelForCausalLM, BertTokenizer, BertForSequenceClassification, StoppingCriteria, StoppingCriteriaList
6
+ from peft import PeftModel, PeftConfig
7
+ import re
8
+ from kobert_transformers import get_tokenizer
9
+
10
+ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ print("Running on device:", torch_device)
12
+ print("CPU threads:", torch.get_num_threads())
13
+
14
+ peft_model_id = "ldhldh/polyglot-ko-1.3b_lora_big_8kstep"
15
+ #18k > 상대의 말까지 하는 이슈가 있음
16
+ #8k > 약간 아쉬운가?
17
+ config = PeftConfig.from_pretrained(peft_model_id)
18
+
19
+ base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
20
+ tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
21
+
22
+ #base_model = AutoModelForCausalLM.from_pretrained("EleutherAI/polyglot-ko-3.8b")
23
+ #tokenizer = AutoTokenizer.from_pretrained("EleutherAI/polyglot-ko-3.8b")
24
+ base_model.eval()
25
+ #base_model.config.use_cache = True
26
+
27
+
28
+ model = PeftModel.from_pretrained(base_model, peft_model_id, device_map="auto")
29
+ model.eval()
30
+ #model.config.use_cache = True
31
+
32
+
33
+ mbti_bert_model_name = "Lanvizu/fine-tuned-klue-bert-base_model_11"
34
+ mbti_bert_model = BertForSequenceClassification.from_pretrained(mbti_bert_model_name)
35
+ mbti_bert_model.eval()
36
+ mbti_bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
37
+
38
+ bert_model_name = "ldhldh/bert_YN_small"
39
+ bert_model = BertForSequenceClassification.from_pretrained(bert_model_name)
40
+ bert_model.eval()
41
+ bert_tokenizer = get_tokenizer()
42
+
43
+
44
+ def mbti_classify(x):
45
+ classifier = pipeline("text-classification", model=mbti_bert_model, tokenizer=mbti_bert_tokenizer, return_all_scores=True)
46
+ result = classifier([x])
47
+ return result[0]
48
+
49
+
50
+ def classify(x):
51
+ input_list = bert_tokenizer.batch_encode_plus([x], truncation=True, padding=True, return_tensors='pt')
52
+ input_ids = input_list['input_ids'].to(bert_model.device)
53
+ attention_masks = input_list['attention_mask'].to(bert_model.device)
54
+ outputs = bert_model(input_ids, attention_mask=attention_masks, return_dict=True)
55
+ return outputs.logits.argmax(dim=1).cpu().tolist()[0]
56
+
57
+ def gen(x, top_p, top_k, temperature, max_new_tokens, repetition_penalty):
58
+ gened = model.generate(
59
+ **tokenizer(
60
+ f"{x}",
61
+ return_tensors='pt',
62
+ return_token_type_ids=False
63
+ ),
64
+ #bad_words_ids = bad_words_ids ,
65
+ max_new_tokens=max_new_tokens,
66
+ min_new_tokens = 5,
67
+ exponential_decay_length_penalty = (max_new_tokens/2, 1.1),
68
+ top_p=top_p,
69
+ top_k=top_k,
70
+ temperature = temperature,
71
+ early_stopping=True,
72
+ do_sample=True,
73
+ eos_token_id=2,
74
+ pad_token_id=2,
75
+ #stopping_criteria = stopping_criteria,
76
+ repetition_penalty=repetition_penalty,
77
+ no_repeat_ngram_size = 2
78
+ )
79
+
80
+ model_output = tokenizer.decode(gened[0])
81
+ return model_output
82
+
83
+ def reset_textbox():
84
+ return gr.update(value='')
85
+
86
+
87
+ with gr.Blocks() as demo:
88
+ duplicate_link = "https://huggingface.co/spaces/beomi/KoRWKV-1.5B?duplicate=true"
89
+ gr.Markdown(
90
+ "duplicated from beomi/KoRWKV-1.5B, baseModel:EleutherAI/polyglot-ko-1.3b"
91
+ )
92
+
93
+ with gr.Row():
94
+ with gr.Column(scale=4):
95
+ user_text = gr.Textbox(
96
+ placeholder='\\nfriend: 우리 여행 갈래? \\nyou:',
97
+ label="User input"
98
+ )
99
+ model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
100
+ button_submit = gr.Button(value="Submit")
101
+ button_bert = gr.Button(value="bert_Sumit")
102
+ button_mbti_bert = gr.Button(value="mbti_bert_Sumit")
103
+ with gr.Column(scale=1):
104
+ max_new_tokens = gr.Slider(
105
+ minimum=1, maximum=200, value=20, step=1, interactive=True, label="Max New Tokens",
106
+ )
107
+ top_p = gr.Slider(
108
+ minimum=0.05, maximum=1.0, value=0.8, step=0.05, interactive=True, label="Top-p (nucleus sampling)",
109
+ )
110
+ top_k = gr.Slider(
111
+ minimum=5, maximum=100, value=30, step=5, interactive=True, label="Top-k (nucleus sampling)",
112
+ )
113
+ temperature = gr.Slider(
114
+ minimum=0.1, maximum=2.0, value=0.5, step=0.1, interactive=True, label="Temperature",
115
+ )
116
+ repetition_penalty = gr.Slider(
117
+ minimum=1.0, maximum=3.0, value=1.2, step=0.1, interactive=True, label="repetition_penalty",
118
+ )
119
+
120
+ button_submit.click(gen, [user_text, top_p, top_k, temperature, max_new_tokens, repetition_penalty], model_output)
121
+ button_bert.click(classify, [user_text], model_output)
122
+ button_mbti_bert.click(mbti_classify, [user_text], model_output)
123
+ demo.queue(max_size=32).launch(enable_queue=True)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ git+https://github.com/huggingface/transformers
2
+ git+https://github.com/huggingface/peft
3
+ torch
4
+ accelerate
5
+ xformers
6
+ kobert-transformers
7
+ gradio