Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import torch | |
| import urllib.request | |
| import urllib.parse | |
| import json | |
| import urllib | |
| import matplotlib.pyplot as plt | |
| from transformers import AutoModelWithLMHead, AutoTokenizer | |
| from classifier import Classifier | |
# DialoGPT checkpoint used for chit-chat replies; the tokenizer and the model
# are both loaded once, at import time, from the same checkpoint.
_DIALOGPT_CHECKPOINT = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(_DIALOGPT_CHECKPOINT)
model = AutoModelWithLMHead.from_pretrained(_DIALOGPT_CHECKPOINT)

# (user_text, bot_text) exchanges appended by chatbot(). Module-level, so the
# same list is shared by every caller of this module.
conversation = []

# Labels of the ten most recent queries, oldest first; seeded with "chitchat".
last_10_conv = ["chitchat"] * 10

# Every label the visualisations know about, in display order.
_TOPIC_LABELS = (
    "politics",
    "environment",
    "technology",
    "healthcare",
    "education",
    "chitchat",
)

# Running number of queries seen per label (drives the pie chart).
topicwise_count = dict.fromkeys(_TOPIC_LABELS, 0)

# Label -> integer id. The visible code only uses this for membership checks.
topics = {label: index for index, label in enumerate(_TOPIC_LABELS)}
def classify(input):
    """Classify *input* with a freshly constructed ``Classifier``.

    A new ``Classifier`` instance is built on every call and its
    ``classify`` result is returned unchanged. ``chatbot`` treats a truthy
    result as "this message is chit-chat".
    """
    return Classifier().classify(input)
def get_chitchat_response(input, history=None):
    """Generate a DialoGPT reply to *input*, continuing the chat in *history*.

    Args:
        input: The user's new message (text).
        history: Token ids of the conversation so far, in the form returned
            by a previous call to this function. ``None`` or an empty list
            starts a new conversation.

    Returns:
        A ``(response, history)`` tuple. ``response`` is a list of
        ``(user_text, bot_text)`` pairs decoded from the whole conversation;
        ``history`` is the updated token-id list to pass to the next call.
    """
    # A None default replaces the shared mutable ``[]`` default and also lets
    # callers pass None for "no history yet" (torch.LongTensor(None) fails).
    if history is None:
        history = []

    # Tokenize the new user message, terminated by the end-of-sequence token.
    new_user_input_ids = tokenizer.encode(
        input + tokenizer.eos_token, return_tensors='pt'
    )

    # Append the new user tokens to the running chat history.
    bot_input_ids = torch.cat(
        [torch.LongTensor(history), new_user_input_ids], dim=-1
    )

    # Generate the reply; the output holds the full conversation so far.
    history = model.generate(
        bot_input_ids,
        max_length=1000,
        pad_token_id=tokenizer.eos_token_id,
        temperature=0.6,
        repetition_penalty=1.3,
    ).tolist()

    # Decode and split on the end-of-text marker: even positions are user
    # turns, odd positions are bot turns. zip drops an unpaired trailing piece.
    turns = tokenizer.decode(history[0]).split("<|endoftext|>")
    response = list(zip(turns[::2], turns[1::2]))
    return response, history
def get_result_from_solr(topic, input):
    """Search the Solr index for *input* and return the top hit.

    Args:
        topic: Topic hint. An empty/``None`` value or ``'Generic'`` runs a
            plain search; any other value is lower-cased and sent as
            ``facet.query`` with faceting on ``subreddit``.
        input: Free-text user query, matched against the ``selftext`` and
            ``title`` fields with the dismax parser.

    Returns:
        A ``(response, topic)`` tuple. When at least one document matches,
        ``response`` is the first document's ``selftext`` and ``topic`` its
        ``subreddit``. Otherwise ``response`` is a fixed apology and ``topic``
        is the (possibly lower-cased) value that was passed in.

    Network errors from ``urlopen`` and decoding errors from ``json.load``
    propagate to the caller.
    """
    response = 'Sorry, I don\'t know about that'

    params = {
        'fl': 'id,score,selftext,subreddit,topic,title',
        'defType': 'dismax',
        'indent': 'true',
        'q.op': 'OR',
        'q': input,
        'qf': 'selftext title',
        'rows': '10',
    }
    # ``not topic`` also covers None, which would otherwise crash on .lower().
    if topic and topic != 'Generic':
        topic = topic.lower()
        params['facet.field'] = 'subreddit'
        params['facet.query'] = topic
        params['facet'] = 'true'

    # urlencode escapes every value, including the topic, which used to be
    # concatenated into the URL raw. quote_via=quote keeps spaces as %20.
    dismaxinurl = (
        'http://34.132.64.242:8080/solr/IRF22P1/select?'
        + urllib.parse.urlencode(params, quote_via=urllib.parse.quote)
    )

    # The context manager closes the HTTP response even if decoding fails.
    with urllib.request.urlopen(dismaxinurl) as reply:
        data = json.load(reply)

    docs = data['response']['docs']
    if docs:
        best = docs[0]
        # A hit may lack a field; fall back to the defaults rather than
        # raising KeyError.
        response = best.get('selftext', response)
        topic = best.get('subreddit', topic)
    return response, topic
def make_autopct(values):
    """Return a pie-chart label formatter showing percentage and raw count.

    The returned callable takes a slice's percentage and renders it as
    ``'<pct with 2 decimals>% (<count>)'``, where the count is the
    percentage applied to ``sum(values)`` and rounded to an integer.
    """
    def _slice_label(pct):
        # The total is recomputed on every call, so it tracks *values*.
        count = int(round(pct * sum(values) / 100.0))
        return f'{pct:.2f}% ({count:d})'

    return _slice_label
def get_pie_chart():
    """Redraw the per-topic query distribution on the current pyplot figure.

    Clears the current figure, draws one pie slice per entry of
    ``topicwise_count`` (labelled with percentage and raw count) and sets the
    title.

    Returns:
        The ``matplotlib.pyplot`` module itself, with the chart drawn on its
        current figure.
    """
    plt.clf()
    labels = list(topicwise_count)
    counts = np.array(list(topicwise_count.values()))

    # When every count is zero (e.g. the first query was an "Other Topic"),
    # the slice fractions are 0/0 = NaN and the autopct callback fails trying
    # to convert NaN to an int. Leave the chart empty in that case.
    if counts.sum() > 0:
        plt.pie(
            counts,
            labels=labels,
            autopct=make_autopct(counts),
            shadow=True,
            wedgeprops={'edgecolor': 'black'},
        )
    plt.title("Topic Distribution among queries")
    return plt
def get_last_n_plot():
    """Plot the labels of the ten most recent queries, oldest to newest.

    The x axis runs from 'Qn-9' (oldest) to 'Qn' (latest); the y values are
    the labels stored in ``last_10_conv``. Draws on the current pyplot figure
    and returns the ``matplotlib.pyplot`` module.
    """
    # Build the tick labels directly in oldest-to-newest order.
    positions = [f'Qn-{age}' for age in range(9, 0, -1)] + ['Qn']
    plt.plot(positions, last_10_conv)
    return plt
def chatbot(topic, input, history=None):
    """Answer one user message and update the conversation statistics.

    The message is first classified. Chit-chat is answered by DialoGPT via
    ``get_chitchat_response``; anything else is answered with the top Solr
    hit via ``get_result_from_solr``.

    Args:
        topic: Topic hint forwarded to the Solr search (unused for chit-chat).
        input: The user's message.
        history: DialoGPT token-id history from the previous call, or
            ``None``/empty to start a new conversation.

    Returns:
        A ``(conversation, history, pie_chart)`` tuple: the module-level list
        of ``(user_text, bot_text)`` exchanges, the updated token-id history
        (unchanged for Solr answers), and the value of ``get_pie_chart()``.
    """
    # A None default replaces the shared mutable ``[]`` default; normalise it
    # to a list because the chit-chat helper feeds it to torch.LongTensor.
    if history is None:
        history = []

    if classify(input):
        response, history = get_chitchat_response(input, history)
        # The last pair is the exchange that was just generated.
        conversation.append(response[-1])
        label = "chitchat"
    else:
        solr_res, topic = get_result_from_solr(topic, input)
        conversation.append((input, solr_res))
        label = topic if topic in topics else "Other Topic"

    # For visualization. The window is updated only after the label is known,
    # so an exception above cannot leave last_10_conv one entry short.
    # "Other Topic" is tracked in the window but has no pie-chart counter.
    if label != "Other Topic":
        topicwise_count[label] += 1
    last_10_conv.append(label)
    del last_10_conv[:-10]

    pie_chart = get_pie_chart()
    return conversation, history, pie_chart