from flask import Flask, request
import requests
import os
import re
import textwrap
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer
from bart import BartForConditionalGeneration
from langdetect import detect
import subprocess

# Default pipeline: used for every message not detected as Vietnamese.
# NOTE: these from_pretrained calls execute at import time and may fetch the
# weights from the Hugging Face Hub, so importing this module can be slow.
tokenizer = AutoTokenizer.from_pretrained("GuysTrans/bart-base-re-attention-seq-512")

# Vietnamese pipeline tokenizer: chosen when the message language is "vi".
vn_tokenizer = AutoTokenizer.from_pretrained("GuysTrans/bart-base-vn-re-attention-vn-tokenizer")

# BartForConditionalGeneration comes from the local `bart` module (not the
# transformers one); presumably a custom "re-attention" BART variant — confirm.
model = BartForConditionalGeneration.from_pretrained(
    "GuysTrans/bart-base-re-attention-seq-512")

# Vietnamese counterpart of `model`, paired with `vn_tokenizer`.
vn_model = BartForConditionalGeneration.from_pretrained(
    "GuysTrans/bart-base-vn-re-attention-vn-tokenizer")

# Phrase -> replacement pairs that post_process applies to generated answers,
# iterating in this (insertion) order. Longer phrases are listed before the
# shorter phrases they contain (the full "Hello and Welcome to 'Ask A Doctor'
# service" before "Hello" and "Ask A Doctor"; "Hello," before "Hello"), so the
# longer text is replaced first. A "" replacement deletes the phrase.
map_words = {
    "Hello and Welcome to 'Ask A Doctor' service": "",
    "Hello,": "",
    "Hi,": "",
    "Hello": "",
    "Hi": "",
    "Ask A Doctor": "MedForum",
    "H C M": "Med Forum"
}

# post_process drops every sentence (text between ".") that contains any of
# these substrings, compared case-insensitively. The commented-out entries
# are currently disabled.
word_remove_sentence = [
    "Welcome to",
    # "hello",
    # "hi",
    # "regards",
    # "dr.",
    # "physician",
    # "welcome",
]


def generate_summary(question, model, tokenizer):
    """Run seq2seq generation on *question* and decode the result.

    Args:
        question: Input text handed to ``tokenizer``.
        model: Model exposing ``device`` and ``generate``.
        tokenizer: Tokenizer paired with ``model``.

    Returns:
        A ``(outputs, texts)`` tuple: the raw return value of
        ``model.generate`` and the list produced by
        ``tokenizer.batch_decode`` with special tokens skipped.
    """
    # Every input is padded/truncated to exactly 512 tokens.
    encoded = tokenizer(
        question,
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    device = model.device
    generation_options = {
        # NOTE(review): 4096 new tokens may exceed the decoder's
        # max_position_embeddings — confirm against the model config.
        "max_new_tokens": 4096,
        # Sampling together with num_beams > 1 gives beam-search multinomial
        # sampling, so replies are not deterministic.
        "do_sample": True,
        "num_beams": 4,
        "top_k": 50,
        "early_stopping": True,
        "no_repeat_ngram_size": 2,
    }
    generated = model.generate(
        encoded.input_ids.to(device),
        attention_mask=encoded.attention_mask.to(device),
        **generation_options,
    )
    texts = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return generated, texts


app = Flask(__name__)

# Send API endpoint used by send_message.
# NOTE(review): v2.6 is a very old Graph API version — confirm it is still served.
FB_API_URL = 'https://graph.facebook.com/v2.6/me/messages'
# Shared secret compared against the `hub.verify_token` query parameter in
# verify_webhook. It can be supplied through the VERIFY_TOKEN environment
# variable. The literal fallback keeps existing deployments working, but it is
# committed to source control and should be rotated.
VERIFY_TOKEN = os.environ.get('VERIFY_TOKEN', '5rApTs/BRm6jtiwApOpIdjBHe73ifm6mNGZOsYkwwAw=')
# Page access token sent with every Send API call. It is read from the
# environment; a KeyError at import time means it is not configured.
PAGE_ACCESS_TOKEN = os.environ['PAGE_ACCESS_TOKEN']


def get_bot_response(message):
    """Build the chatbot's reply to one user message.

    The language is detected with ``langdetect``. Text detected as Vietnamese
    ("vi") is answered by ``vn_model``/``vn_tokenizer`` with the Vietnamese
    template. Anything else, including text whose language cannot be
    detected, goes to ``model``/``tokenizer`` with the English template. The
    first decoded output is cleaned by ``post_process`` and inserted into the
    template.

    Args:
        message: The user's text.

    Returns:
        The reply string.
    """
    from langdetect.lang_detect_exception import LangDetectException

    try:
        lang = detect(message)
    except LangDetectException:
        # detect() raises on text with no usable features (e.g. empty, digits
        # or emoji only). Fall back to the default pipeline rather than
        # failing the whole request.
        lang = None
    # NOTE(review): langdetect is non-deterministic on short/ambiguous text
    # unless DetectorFactory.seed is set; consider seeding it at import time.

    if lang == "vi":
        model_use, tokenizer_use = vn_model, vn_tokenizer
        template = "Chào mừng bạn đến với dịch vụ MedForRum chatbot. %s. Cảm ơn bạn đã sử dụng MedForum."
    else:
        model_use, tokenizer_use = model, tokenizer
        # NOTE(review): "MedForRum" (here and in the Vietnamese template) looks
        # like a typo for "MedForum" — confirm before changing user-visible text.
        template = "Welcome to MedForRum chatbot service. %s. Thanks for asking on MedForum."

    _, decoded = generate_summary(message, model_use, tokenizer_use)
    return template % post_process(decoded[0])


def verify_webhook(req, verify_token=None):
    """Answer the webhook verification GET request.

    Echoes the ``hub.challenge`` query parameter back when
    ``hub.verify_token`` matches the configured token, as Facebook's webhook
    verification handshake expects.

    Args:
        req: Incoming request; only ``req.args`` (a mapping of query
            parameters) is read.
        verify_token: Expected token. Defaults to the module-level
            ``VERIFY_TOKEN``.

    Returns:
        The ``hub.challenge`` value (``""`` if it is absent) on a match,
        otherwise the string ``"incorrect"``.
    """
    import hmac

    expected = VERIFY_TOKEN if verify_token is None else verify_token
    supplied = req.args.get("hub.verify_token") or ""
    # Compare bytes in constant time. This avoids leaking the token through
    # response timing, and compare_digest rejects non-ASCII str input, so the
    # values are encoded first.
    if supplied and hmac.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8")):
        # Returning None would make Flask fail with a 500, so default to "".
        return req.args.get("hub.challenge", "")
    return "incorrect"


def respond(sender, message):
    """Answer one user message.

    Computes the bot reply for *message*, delivers it to *sender* through
    ``send_message``, and returns the reply text.
    """
    reply = get_bot_response(message)
    send_message(sender, reply)
    return reply


def is_user_message(message):
    """Tell whether a webhook event carries text typed by the user.

    The result is truthy only when the event has a non-empty ``message``
    whose ``text`` is non-empty and which is not flagged ``is_echo``
    (presumably echoes of the page's own outgoing messages). Otherwise the
    first falsy value met is returned, so use the result only for its
    truthiness.
    """
    body = message.get('message')
    if not body:
        return body
    text = body.get('text')
    if not text:
        return text
    return not body.get("is_echo")


@app.route("/webhook", methods=['GET', 'POST'])
def listen():
    """Handle the Messenger webhook at ``/webhook``.

    Non-POST requests (GET, and the HEAD that Flask adds automatically) are
    the verification handshake and are answered by ``verify_webhook``. A POST
    carries a JSON object whose ``entry`` list holds ``messaging`` events.
    Every event accepted by ``is_user_message`` is answered via ``respond``,
    and the POST is acknowledged with ``"ok"``.
    """
    if request.method != 'POST':
        return verify_webhook(request)

    # A single delivery may batch several entries, each with several events,
    # so walk all of them rather than only entry[0]. A non-JSON body or
    # missing keys mean "nothing to do" instead of an unhandled 500.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    for entry in payload.get('entry', []):
        for event in entry.get('messaging', []):
            if is_user_message(event):
                respond(event['sender']['id'], event['message']['text'])

    # NOTE(review): respond() runs model inference synchronously before this
    # acknowledgement. If that exceeds Facebook's webhook timeout, the
    # delivery may be retried; consider replying in the background.
    return "ok"


def send_message(recipient_id, text, timeout=30):
    """Send *text* to *recipient_id* through the Messenger Send API.

    Args:
        recipient_id: Id of the user to reply to (the ``sender.id`` of the
            incoming webhook event).
        text: Message body.
        timeout: Seconds to wait for the Graph API. Without a timeout,
            ``requests`` can block the worker indefinitely; on expiry a
            ``requests.exceptions.Timeout`` is raised.

    Returns:
        The decoded JSON body of the Graph API response.
    """
    payload = {
        'message': {'text': text},
        'recipient': {'id': recipient_id},
        'notification_type': 'regular',
        # Messages sent here always answer an incoming user message.
        'messaging_type': 'RESPONSE',
    }
    response = requests.post(
        FB_API_URL,
        params={'access_token': PAGE_ACCESS_TOKEN},
        json=payload,
        timeout=timeout,
    )
    return response.json()


@app.route("/webhook/chat", methods=['POST'])
def chat():
    """Synchronous JSON chat endpoint.

    Expects ``{"message": "<user text>"}`` and returns the reply in the HTTP
    response as ``{"message": "<bot reply>"}`` rather than sending it
    through Messenger. A non-JSON body, or a missing, non-string or blank
    ``message``, gets HTTP 400 instead of an unhandled exception (HTTP 500).
    """
    payload = request.get_json(silent=True)
    message = payload.get('message') if isinstance(payload, dict) else None
    if not isinstance(message, str) or not message.strip():
        return {"error": "request body must be a JSON object with a non-empty 'message' string"}, 400
    response = get_bot_response(message)
    return {"message": response}

def _phrase_pattern(phrase):
    """Compile *phrase* as a literal, case-insensitive, whole-word pattern."""
    pattern = re.escape(phrase)
    # Anchor only edges that are word characters. A "\b" next to punctuation
    # (e.g. after the comma in "Hello,") would instead demand a word character
    # on the far side.
    if re.match(r"\w", phrase):
        pattern = r"\b" + pattern
    if re.search(r"\w\Z", phrase):
        pattern += r"\b"
    return re.compile(pattern, re.IGNORECASE)


def post_process(output, remove_words=None, replacements=None):
    """Clean up a generated answer before it is sent to the user.

    1. Split *output* on "." and drop every sentence that contains any of
       *remove_words* (case-insensitive substring match).
    2. Replace each phrase of *replacements*, in insertion order, as a
       literal, case-insensitive, whole-word match.
    3. Dedent, strip and re-wrap the text to 120 columns.

    Args:
        output: Decoded model output.
        remove_words: Substrings marking a sentence for removal. Defaults to
            the module-level ``word_remove_sentence``.
        replacements: Mapping of phrase -> replacement text. Defaults to the
            module-level ``map_words``.

    Returns:
        The cleaned, wrapped text.
    """
    if remove_words is None:
        remove_words = word_remove_sentence
    if replacements is None:
        replacements = map_words

    # Filter into a new list. Removing from the list being iterated skips the
    # element right after each removal.
    needles = [word.lower() for word in remove_words]
    kept = [
        sentence for sentence in output.split(".")
        if not any(needle in sentence.lower() for needle in needles)
    ]
    output = ".".join(kept)

    for phrase, substitute in replacements.items():
        if not phrase:
            continue  # an empty pattern would match between every character
        # Flags go into the compiled pattern: re.sub's 4th positional argument
        # is `count`, not `flags`. The phrase is escaped so it matches
        # literally. A function replacement keeps backslashes in the
        # substitute from being interpreted as group references.
        output = _phrase_pattern(phrase).sub(lambda _match, text=substitute: text, output)

    return textwrap.fill(textwrap.dedent(output).strip(), width=120)
    


# Start a reverse SSH tunnel to serveo.net that forwards remote port 80 (bound
# to the "guysmedchatt" name) to the local server on port 7860, so the webhook
# is reachable publicly. autossh restarts ssh whenever ssh exits. "-M 0"
# disables autossh's own monitoring port, so a silently dead connection is
# only noticed if ssh itself exits. The ServerAlive* options make ssh quit
# after ~90 s without a server reply, letting autossh reconnect.
# NOTE(review): this runs at import time, so every process importing this
# module (e.g. each server worker) starts its own tunnel. "id_rsa" is resolved
# relative to the working directory, and StrictHostKeyChecking=no disables
# host-key verification (MITM risk). Confirm all three are acceptable.
# The handle is kept so the tunnel can be terminated or inspected.
tunnel_process = subprocess.Popen([
    "autossh", "-M", "0", "-tt",
    "-o", "StrictHostKeyChecking=no",
    "-o", "ServerAliveInterval=30",
    "-o", "ServerAliveCountMax=3",
    "-i", "id_rsa",
    "-R", "guysmedchatt:80:localhost:7860",
    "serveo.net",
])