---
language:
- ko
tags:
- generated_from_keras_callback
model-index:
- name: t5-base-korean-chit-chat
  results: []
---

# t5-base-korean-chit-chat

This model is paust/pko-t5-base fine-tuned on the AIHUB "한국어 SNS" (Korean SNS conversation) dataset. Given a conversation of the kind exchanged on social media, it infers the next turn of the dialogue.

## Usage

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import nltk

# Sentence splitter used only to display the output (newer NLTK releases may also need nltk.download('punkt_tab'))
nltk.download('punkt')

model_dir = "lcw99/t5-base-korean-chit-chat"
max_input_length = 1024

# Speakers A and B alternate; the prompt ends with "B: " so the model replies as B.
# (A: Shall we go shopping? / B: Yes, sounds good. / A: When should we go? / B:)
text = """
A: 쇼핑하러 갈까?
B: 응 좋아.
A: 언제 갈까?
B: """
inputs = [text]

tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)

inputs = tokenizer(inputs, max_length=max_input_length, truncation=True, return_tensors="pt")
# Beam search with sampling; return 3 candidate replies
output = model.generate(**inputs, num_beams=3, do_sample=True, min_length=20, max_length=500, num_return_sequences=3)
for i in range(3):
    print("---", i)
    decoded_output = tokenizer.decode(output[i], skip_special_tokens=True)
    sentences = nltk.sent_tokenize(decoded_output)
    print(sentences)

# Interactive chat: up to 100 turns, keeping only the last 5 lines of history in the prompt
chat_history = []
for step in range(100):
    print("")
    user_input = input(">> User: ")
    chat_history.append("A: " + user_input)
    while len(chat_history) > 5:
        chat_history.pop(0)
    hist = ""
    for chat in chat_history:
        hist += "\n" + chat
    hist += "\nB: "
    bot_input_ids = tokenizer.encode(hist, return_tensors='pt')

    # Generate a reply of at most 200 tokens
    chat_history_ids = model.generate(
        bot_input_ids,
        max_length=200,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        #top_k=100,
        #top_p=0.7,
        #temperature = 0.1
    )
    # "#@이름#" ("name") is replaced with the placeholder "OOO"
    bot_text = tokenizer.decode(chat_history_ids[0], skip_special_tokens=True).replace("#@이름#", "OOO")
    bot_text = bot_text.replace("\n", " / ")
    chat_history.append("B: " + bot_text)
    # Pretty-print the bot's latest reply
    print("Bot: {}".format(bot_text))
```

### Framework versions

- Transformers 4.22.1
- TensorFlow 2.10.0
- Datasets 2.5.1
- Tokenizers 0.12.1
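The interactive loop in the Usage section builds each prompt from the last five lines of history, ending with an open `B: ` turn, and post-processes each reply. Below is a minimal sketch of those two steps as standalone helpers. The names `build_prompt` and `clean_reply` are hypothetical and are not part of the model or of `transformers`. They let you check the prompt format without loading the model.

```python
def build_prompt(history: list[str], max_lines: int = 5) -> str:
    """Join the last `max_lines` turns, each preceded by a newline, and end with an open "B: " turn."""
    window = history[-max_lines:]
    return "".join("\n" + line for line in window) + "\nB: "


def clean_reply(raw: str) -> str:
    """Replace the "#@이름#" name marker with "OOO" and flatten multi-line replies with " / "."""
    return raw.replace("#@이름#", "OOO").replace("\n", " / ")


history = ["A: 안녕", "B: 안녕하세요", "A: 뭐해?"]
print(repr(build_prompt(history)))  # '\nA: 안녕\nB: 안녕하세요\nA: 뭐해?\nB: '
```

Slicing with `history[-max_lines:]` yields the same prompt as the `pop(0)` loop above while leaving the full history intact.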