chat-gradio / personalized-chat-bot /scripts /fit_personality_clustering.py
j.gilyazev
add personalized-chat-bot
0766044
raw history blame
No virus
1.91 kB
import argparse
from datasets import load_dataset
from models.personality_clustering import PersonalityClustering
import os
"""Пример запуска
python -m scripts.fit_personality_clustering --clustering-path data/models --n-clusters 500
"""
PERSONACHAT_DATASET = "bavard/personachat_truecased"
def load_persona_chat_personalities(personachat_dataset):
dataset = load_dataset(personachat_dataset)
train_personalities = [sent for persona in dataset['train']['personality']
for sent in persona]
test_personalities = [sent for persona in dataset['train']['personality']
for sent in persona]
personalities = list(set(train_personalities) | set(test_personalities))
return personalities
def parse_args(args=None):
    """Parse command-line options for the personality-clustering fitting script.

    Args:
        args: Optional list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``clustering_path``, ``n_clusters`` and ``model_name``.
    """
    parser = argparse.ArgumentParser(
        add_help=True,
        description="Fit a personality clustering model and save it.",
    )
    # Both single-dash and double-dash spellings are kept for backward compatibility.
    # The output path is required: without it, saving fails only after the
    # (expensive) fit has already run.
    parser.add_argument('-clustering-path', '--clustering-path', type=str, required=True,
                        help='Directory where the fitted clustering model is saved.')
    parser.add_argument('-n-clusters', '--n-clusters', type=int, default=500,
                        help='The number of clusters to form.')
    parser.add_argument('-model-name', '--model-name', type=str, default=None, required=False,
                        help='File name for the saved model; by default it is derived '
                             "from the cluster count and the clustering model's name.")
    return parser.parse_args(args)
def main():
    """Fit a personality clustering on PersonaChat personas and save it.

    The fitted model is saved to ``<clustering-path>/<model-name>``. When
    ``--model-name`` is not given, the name is built from the cluster count
    and the model's ``model_name`` attribute.
    """
    args = parse_args()
    # Check the output location before the slow download and fit, so a missing
    # --clustering-path fails immediately instead of after fitting.
    if args.clustering_path is None:
        raise SystemExit('error: --clustering-path is required')

    personalities = load_persona_chat_personalities(PERSONACHAT_DATASET)
    print('Data loaded')
    model = PersonalityClustering(n_clusters=args.n_clusters)
    print('Model fitting')
    model.fit(personalities)
    print('Model fitted')

    if args.model_name is None:
        model_name = f'personality_clustering_{model.n_clusters}_{model.model_name}_k-means.pkl'
    else:
        model_name = args.model_name
    save_path = os.path.join(args.clustering_path, model_name)
    # Make sure the target directory exists so the fitted model is not lost.
    # NOTE(review): it is unclear from here whether `save` creates it itself.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    model.save(save_path)
    print(f'{model_name} saved')


if __name__ == '__main__':
    main()