from threading import Thread

import re

import torch
import gradio as gr
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM,
    BertTokenizer,
    BertForSequenceClassification,
    StoppingCriteria,
    StoppingCriteriaList,
)
from peft import PeftModel, PeftConfig
from kobert_transformers import get_tokenizer

torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", torch_device)
print("CPU threads:", torch.get_num_threads())

peft_model_id = "ldhldh/polyglot-ko-1.3b_lora_big_8kstep"
# 18k-step checkpoint: has an issue where it also generates the other party's lines
# 8k-step checkpoint: slightly underwhelming?
config = PeftConfig.from_pretrained(peft_model_id)

base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

#base_model = AutoModelForCausalLM.from_pretrained("EleutherAI/polyglot-ko-3.8b")
#tokenizer = AutoTokenizer.from_pretrained("EleutherAI/polyglot-ko-3.8b")
base_model.eval()
#base_model.config.use_cache = True


model = PeftModel.from_pretrained(base_model, peft_model_id, device_map="auto")
model.eval()
#model.config.use_cache = True


mbti_bert_model_name = "Lanvizu/fine-tuned-klue-bert-base_model_11"
mbti_bert_model = BertForSequenceClassification.from_pretrained(mbti_bert_model_name)
mbti_bert_model.eval()
# The classifier is fine-tuned from klue/bert-base, so load the matching Korean
# tokenizer; "bert-base-uncased" would not match the model's vocabulary.
mbti_bert_tokenizer = BertTokenizer.from_pretrained("klue/bert-base")

bert_model_name = "ldhldh/bert_YN_small"
bert_model = BertForSequenceClassification.from_pretrained(bert_model_name)
bert_model.eval()
bert_tokenizer = get_tokenizer()


# Build the MBTI text-classification pipeline once at import time rather than
# on every button press; return_all_scores=True yields a score for every label.
mbti_classifier = pipeline(
    "text-classification",
    model=mbti_bert_model,
    tokenizer=mbti_bert_tokenizer,
    return_all_scores=True,
)


def mbti_classify(x):
    # Returns a list of {"label", "score"} dicts, one per MBTI class.
    return mbti_classifier([x])[0]


def classify(x):
    # Yes/no style classification (inferred from the "bert_YN_small" model name);
    # returns the argmax class index as a plain int.
    input_list = bert_tokenizer.batch_encode_plus([x], truncation=True, padding=True, return_tensors='pt')
    input_ids = input_list['input_ids'].to(bert_model.device)
    attention_masks = input_list['attention_mask'].to(bert_model.device)
    with torch.no_grad():
        outputs = bert_model(input_ids, attention_mask=attention_masks, return_dict=True)
    return outputs.logits.argmax(dim=1).cpu().tolist()[0]
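

# A minimal sketch (not wired in) of the StoppingCriteria that the commented-out
# `stopping_criteria` argument in gen() below could use: stop generation as soon
# as any of the given token ids is produced. The token ids are assumptions for
# illustration, not values from the original app.
class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_token_ids):
        self.stop_token_ids = stop_token_ids

    def __call__(self, input_ids, scores, **kwargs):
        # input_ids has shape (batch, seq_len); check the last generated token.
        return input_ids[0, -1].item() in self.stop_token_ids


# Hypothetical wiring: stop on the EOS token id (2) that gen() passes below.
# stopping_criteria = StoppingCriteriaList([StopOnTokens([2])])
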

def gen(x, top_p, top_k, temperature, max_new_tokens, repetition_penalty):
    # Tokenize the prompt and move it to the model's device so generation
    # also works when the model is loaded on GPU.
    inputs = tokenizer(x, return_tensors='pt', return_token_type_ids=False).to(model.device)
    gened = model.generate(
        **inputs,
        #bad_words_ids=bad_words_ids,
        max_new_tokens=max_new_tokens,
        min_new_tokens=5,
        # Start penalizing length halfway through the budget; the start index
        # must be an int, so cast the slider value.
        exponential_decay_length_penalty=(int(max_new_tokens / 2), 1.1),
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        do_sample=True,
        eos_token_id=2,
        pad_token_id=2,
        #stopping_criteria=stopping_criteria,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=2,
    )

    model_output = tokenizer.decode(gened[0])
    return model_output
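

# A minimal streaming variant (a sketch, not wired into the UI): this is the
# usual pattern that pairs the `Thread` import above with transformers'
# TextIteratorStreamer, running generate() in a background thread and yielding
# text as it is produced. Generation arguments are trimmed down for brevity.
def gen_stream(x, max_new_tokens=20):
    from transformers import TextIteratorStreamer

    inputs = tokenizer(x, return_tensors='pt', return_token_type_ids=False).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        eos_token_id=2,
        pad_token_id=2,
    )
    Thread(target=model.generate, kwargs=kwargs).start()
    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial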

def reset_textbox():
    return gr.update(value='')
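

# Post-processing sketch using the `re` import above (an assumption about its
# intended use, not wired into the UI): trim the decoded text at the next
# "friend:" turn so the bot does not answer for the other party, the issue
# noted for the 18k-step checkpoint. The "friend:"/"you:" turn markers follow
# the prompt format shown in the textbox placeholder below.
def trim_reply(text):
    match = re.search(r"\bfriend:", text)
    return text[:match.start()].strip() if match else text.strip()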


with gr.Blocks() as demo:
    duplicate_link = "https://huggingface.co/spaces/beomi/KoRWKV-1.5B?duplicate=true"
    gr.Markdown(
        "Duplicated from beomi/KoRWKV-1.5B. Base model: EleutherAI/polyglot-ko-1.3b"
    )
    
    with gr.Row():
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                placeholder='\\nfriend: 우리 여행 갈래? \\nyou:',
                label="User input"
            )
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit")
            button_bert = gr.Button(value="bert_Submit")
            button_mbti_bert = gr.Button(value="mbti_bert_Submit")
        with gr.Column(scale=1):
            max_new_tokens = gr.Slider(
                minimum=1, maximum=200, value=20, step=1, interactive=True, label="Max New Tokens",
            )
            top_p = gr.Slider(
                minimum=0.05, maximum=1.0, value=0.8, step=0.05, interactive=True, label="Top-p (nucleus sampling)",
            )
            top_k = gr.Slider(
                minimum=5, maximum=100, value=30, step=5, interactive=True, label="Top-k",
            )
            temperature = gr.Slider(
                minimum=0.1, maximum=2.0, value=0.5, step=0.1, interactive=True, label="Temperature",
            )
            repetition_penalty = gr.Slider(
                minimum=1.0, maximum=3.0, value=1.2, step=0.1, interactive=True, label="Repetition penalty",
            )
    
    button_submit.click(gen, [user_text, top_p, top_k, temperature, max_new_tokens, repetition_penalty], model_output)
    button_bert.click(classify, [user_text], model_output)
    button_mbti_bert.click(mbti_classify, [user_text], model_output)
    # .queue() already enables queuing; the deprecated enable_queue launch flag is not needed.
    demo.queue(max_size=32).launch()