choi4th4570 committed on
Commit
1207ce4
1 Parent(s): 266f469

Upload final - 복사본.py

Files changed (1)
  1. final - 복사본.py +85 -0
final - 복사본.py ADDED
@@ -0,0 +1,85 @@
+ import gradio as gr
+
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from peft import PeftModel, PeftConfig
+
+ title = "🤖응답하라 챗봇🤖"  # "Reply, Chatbot"
+ description = "특정 도메인에 특화된 챗봇"  # "A chatbot specialized for a particular domain"
+ examples = [["특허란 무엇인가요?"], ["자동차에 관한 특허는 무엇이 있나요?"]]  # "What is a patent?" / "What patents are there relating to cars?"
+
+ def gen(message, history):
+     # Gradio's ChatInterface calls fn(message, history); history is unused here.
+     prompt = f"##질문:{message}\n##답변:"  # fine-tuning prompt format ("질문" = question, "답변" = answer)
+     inputs = tokenizer(prompt, return_tensors='pt', return_token_type_ids=False).to('cuda')
+     gened = model.generate(
+         **inputs,
+         max_new_tokens=256,
+         no_repeat_ngram_size=6,
+         temperature=0.7,
+         do_sample=True,
+         eos_token_id=2,
+         pad_token_id=2,
+     )
+     # Decode only the newly generated tokens, skipping the echoed prompt.
+     return tokenizer.decode(gened[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True)
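+
+ # Hypothetical smoke test (works only after the model and tokenizer are loaded below):
+ # print(gen("특허란 무엇인가요?", []))  # "What is a patent?"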
+
+
+ model_root = "polyglot-ko-1.3b-lora-knk"  # local directory holding the LoRA adapter
+
+ # Load the adapter config, then the base model it points to, then attach the adapter.
+ config = PeftConfig.from_pretrained(model_root)
+ tokenizer = AutoTokenizer.from_pretrained('EleutherAI/polyglot-ko-1.3b')
+ model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, device_map={"": 0})
+ model = PeftModel.from_pretrained(model, model_root)
+
+ model.eval()
+ model.config.use_cache = True  # keep the KV cache enabled for inference
+
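+ # Optionally, the adapter weights could be folded into the base model for
+ # slightly faster inference; a sketch using peft's merge_and_unload():
+ # model = model.merge_and_unload()
+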
+ # Quick CLI test loop, kept disabled (enter '1' to quit):
+ # while True:
+ #     ques = input('질문:')  # '질문' = question
+ #     if ques == '1':
+ #         break
+ #     print(gen(ques, []))
+
+ demo = gr.ChatInterface(
+     fn=gen,
+     title=title,
+     description=description,
+     examples=examples,
+     theme="xiaobaiyuan/theme_brief",
+ )
+
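+ # share=True serves the app locally and also requests a temporary public
+ # gradio.live link, handy when demoing from a remote machine.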
+ if __name__ == "__main__":
+     demo.launch(share=True)