# Notebook-export artifacts (originally: "Spaces:", "Runtime error",
# "Runtime error") — kept as a comment so the file is valid Python.
# -*- coding: utf-8 -*-
"""Meena_A_Multilingual_Chatbot (1).ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1-IfUcnDUppyMArHonc_iesEcN2gSKU-j
"""
# !pip3 install transformers
# !pip install -q translate
# !pip install polyglot
# !pip install Pyicu
# !pip install Morfessor
# !pip install pycld2
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from translate import Translator
from polyglot.detect import Detector

# DialoGPT checkpoint to use; the smaller variant is kept for reference.
# model_name = "microsoft/DialoGPT-small"
model_name = "microsoft/DialoGPT-large"

# Load the tokenizer and causal-LM weights once, at module import time
# (downloads the checkpoint on first run).
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Earlier CLI version of the chat loop (superseded by the Gradio
# interface below): chats with nucleus sampling & tweaked temperature.
# step = -1
# while True:
#     step += 1
#     # take user input
#     text = input(">> You:>")
#     detected_language = Detector(text, quiet=True).language.code
#     translator = Translator(from_lang=detected_language, to_lang="en")
#     translated_input = translator.translate(text)
#     print(translated_input)
#     if text.lower().find("bye") != -1:
#         print(f">> Meena:> Bye Bye!")
#         break
#     # encode the input and add end of string token
#     input_ids = tokenizer.encode(translated_input + tokenizer.eos_token, return_tensors="pt")
#     # concatenate new user input with chat history (if there is)
#     bot_input_ids = torch.cat([chat_history_ids, input_ids], dim=-1) if step > 0 else input_ids
#     # generate a bot response
#     chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id, do_sample=True, top_p=0.9, top_k=50, temperature=0.7, num_beams=5, no_repeat_ngram_size=2)
#     # print the output
#     output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
#     print(output)
#     translator = Translator(from_lang="en", to_lang=detected_language)
#     translated_output = translator.translate(output)
#     print(f">> Meena:> {translated_output}")
# !pip install gradio
import gradio as gr
def generate_text(text):
    """Generate a single chatbot reply in the user's own language.

    The message is language-detected with polyglot, translated to
    English, fed to DialoGPT, and the model's reply is translated back
    into the detected language.

    Args:
        text: Raw user message (any polyglot-detectable language).

    Returns:
        The chatbot's reply translated into the detected input
        language, or a farewell string when the input contains "bye".
    """
    # Detect the user's language so we can translate both directions
    # around the English-only DialoGPT model.
    detected_language = Detector(text, quiet=True).language.code

    # End the conversation on "bye". (The original printed a farewell
    # and fell out of its loop, returning None, which rendered as an
    # empty Gradio output box.)
    if "bye" in text.lower():
        return "Bye Bye!"

    translated_input = Translator(from_lang=detected_language, to_lang="en").translate(text)

    # Encode the (English) input and append the end-of-string token.
    input_ids = tokenizer.encode(translated_input + tokenizer.eos_token, return_tensors="pt")

    # NOTE(review): the original kept a step counter and tried to
    # concatenate chat history, but `chat_history_ids` was a local that
    # never survived between calls and the loop always returned on its
    # first pass — so each call is a fresh, single-turn exchange.
    chat_history_ids = model.generate(
        input_ids,
        max_length=1000,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        top_p=0.9,
        top_k=50,
        temperature=0.7,
        num_beams=5,
        no_repeat_ngram_size=2,
    )

    # Decode only the newly generated tokens (everything after the prompt).
    output = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)

    # Translate the English reply back into the user's language.
    return Translator(from_lang="en", to_lang=detected_language).translate(output)
# Expose the chatbot through a simple one-box-in, one-box-out Gradio UI.
output_text = gr.Textbox()
gr.Interface(
    generate_text,
    "textbox",
    output_text,
    title="Meena",
    description="Meena- A Multilingual Chatbot",
).launch(debug=False)
# !gradio deploy