"""Topic-aware chatbot.

Routes chit-chat utterances to DialoGPT and topical queries to a Solr
index, while tracking per-topic query counts for visualization (a pie
chart of the overall topic distribution and a timeline of the last 10
queries).
"""

import json
import urllib.parse
import urllib.request

import matplotlib.pyplot as plt
import numpy as np
import torch
from transformers import AutoModelWithLMHead, AutoTokenizer

from classifier import Classifier

# NOTE(review): AutoModelWithLMHead is deprecated in recent transformers
# releases; AutoModelForCausalLM is the modern equivalent — confirm the
# pinned library version before switching.
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelWithLMHead.from_pretrained("microsoft/DialoGPT-medium")

# Base endpoint of the Solr collection (previously duplicated inline).
SOLR_BASE_URL = 'http://34.132.64.242:8080/solr/IRF22P1/select'

# Full (user, bot) exchange history for the session.
conversation = []
# Rolling window of the topics of the last 10 queries, oldest first.
last_10_conv = ["chitchat", "chitchat", "chitchat", "chitchat", "chitchat",
                "chitchat", "chitchat", "chitchat", "chitchat", "chitchat"]
# Cumulative per-topic query counts feeding the pie chart.
topicwise_count = {"politics": 0, "environment": 0, "technology": 0,
                   "healthcare": 0, "education": 0, "chitchat": 0}
# Known topic labels (value = index; used only for membership checks here).
topics = {"politics": 0, "environment": 1, "technology": 2,
          "healthcare": 3, "education": 4, "chitchat": 5}


def classify(input):
    """Return the project Classifier's verdict for *input*.

    A truthy result means the utterance is chit-chat rather than a
    topical query.
    """
    c = Classifier()
    return c.classify(input)


def get_chitchat_response(input, history=None):
    """Generate a DialoGPT reply to *input*, continuing *history*.

    Args:
        input: the new user utterance.
        history: flat list of prior token ids, or None/[] for a fresh
            conversation. (Fixed: was a mutable default argument.)

    Returns:
        (pairs, history): list of (user, bot) text pairs decoded from the
        generation, and the updated token-id history as a nested list.
    """
    if history is None:
        history = []
    # Tokenize the new user sentence, terminated by the EOS token.
    new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token,
                                          return_tensors='pt')
    # Append the new user tokens to the chat history.
    bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids],
                              dim=-1)
    # Generate a response (greedy w/ temperature + repetition penalty).
    history = model.generate(
        bot_input_ids,
        max_length=1000,
        pad_token_id=tokenizer.eos_token_id,
        temperature=0.6,
        repetition_penalty=1.3,
    ).tolist()
    # Decode and split on the EOS marker, then pair up (user, bot) turns.
    response = tokenizer.decode(history[0]).split("<|endoftext|>")
    response = [(response[i], response[i + 1])
                for i in range(0, len(response) - 1, 2)]
    return response, history


def get_result_from_solr(topic, input):
    """Query Solr (dismax) for *input*, optionally faceted by *topic*.

    Args:
        topic: topic hint; '' or 'Generic' searches without faceting,
            anything else is lowercased and used as a facet query.
        input: the user's query text.

    Returns:
        (response, topic): the top document's ``selftext`` and its
        ``subreddit`` when a hit exists, otherwise a fallback message and
        the (possibly lowercased) input topic.
    """
    response = 'Sorry, I don\'t know about that'
    search_fields = "selftext title"
    if topic == '' or topic == 'Generic':
        url = (SOLR_BASE_URL
               + '?fl=id%2Cscore%2Cselftext%2Csubreddit%2Ctopic%2Ctitle'
               + '&defType=dismax&indent=true&q.op=OR'
               + '&q=' + urllib.parse.quote(input)
               + '&qf=' + urllib.parse.quote(search_fields)
               + '&rows=10')
    else:
        topic = topic.lower()
        url = (SOLR_BASE_URL
               + '?facet.field=subreddit'
               # Bugfix: topic was interpolated unescaped into the URL.
               + '&facet.query=' + urllib.parse.quote(topic)
               + '&facet=true'
               + '&fl=id%2Cscore%2Cselftext%2Csubreddit%2Ctopic%2Ctitle'
               + '&defType=dismax&indent=true&q.op=OR'
               + '&q=' + urllib.parse.quote(input)
               + '&qf=' + urllib.parse.quote(search_fields)
               + '&rows=10')
    # Bugfix: the response object was never closed (resource leak).
    with urllib.request.urlopen(url) as resp:
        data = json.load(resp)
    docs = data['response']['docs']
    if docs:
        response = docs[0]['selftext']
        topic = docs[0]['subreddit']
    return response, topic


def make_autopct(values):
    """Return an autopct callable rendering ``pp.pp% (count)`` per wedge.

    Args:
        values: the raw counts behind the pie wedges.
    """
    # Hoisted out of the per-wedge callback (loop-invariant).
    total = sum(values)

    def my_autopct(pct):
        val = int(round(pct * total / 100.0))
        return '{p:.2f}% ({v:d})'.format(p=pct, v=val)

    return my_autopct


def get_pie_chart():
    """Redraw the topic-distribution pie on the shared figure.

    Returns:
        The module-level ``matplotlib.pyplot`` for the caller to render.
    """
    plt.clf()
    counts = np.array(list(topicwise_count.values()))
    labels = list(topicwise_count.keys())
    plt.pie(counts, labels=labels, autopct=make_autopct(counts),
            shadow=True, wedgeprops={'edgecolor': 'black'})
    plt.title("Topic Distribution among queries")
    return plt


def get_last_n_plot():
    """Plot the topic of the last 10 queries (oldest on the left).

    Returns:
        The module-level ``matplotlib.pyplot`` for the caller to render.
    """
    queries = ['Qn', 'Qn-1', 'Qn-2', 'Qn-3', 'Qn-4',
               'Qn-5', 'Qn-6', 'Qn-7', 'Qn-8', 'Qn-9']
    queries.reverse()
    plt.plot(queries, last_10_conv)
    return plt


def chatbot(topic, input, history=None):
    """Handle one user turn: classify, answer, and update statistics.

    Args:
        topic: topic hint forwarded to the Solr search path.
        input: the user's utterance.
        history: DialoGPT token-id history from the previous turn, or
            None for a fresh conversation. (Fixed: was a mutable default
            argument.)

    Returns:
        (conversation, history, pie_chart): the full exchange log, the
        updated chit-chat token history, and the refreshed pie chart.
    """
    if history is None:
        history = []
    # Keep the rolling window at 10 entries before appending this turn.
    if len(last_10_conv) == 10:
        last_10_conv.pop(0)
    is_chitchat = classify(input)
    if is_chitchat:
        response, history = get_chitchat_response(input, history)
        conversation.append(response[-1])
        # For visualization.
        last_10_conv.append("chitchat")
        topicwise_count["chitchat"] += 1
    else:
        solr_res, topic = get_result_from_solr(topic, input)
        conversation.append((input, solr_res))
        # For visualization: only count topics we track; anything else
        # (e.g. an unexpected subreddit name) shows as "Other Topic".
        if topic in topics:
            last_10_conv.append(topic)
            topicwise_count[topic] += 1
        else:
            last_10_conv.append("Other Topic")
    pie_chart = get_pie_chart()
    return conversation, history, pie_chart