import torch
from PIL import Image
from torchvision import transforms
import gradio as gr
import json
import urllib, urllib.request
from diffusers import DiffusionPipeline
from transformers import pipeline

# First Page
def demo_tab(image):
    """Echo the uploaded image back unchanged.

    Args:
        image: Whatever the input component supplies.

    Returns:
        The very same object that was passed in.
    """
    return image

# Second Page
# SDXL pipelines, filled on first use. Keeping them here means the weights
# are loaded and moved to the GPU once per process instead of once per call.
_SDXL_PIPELINES = {}


def _get_sdxl_pipelines():
    """Return the cached ``(base, refiner)`` SDXL pipelines, loading them on first use."""
    if not _SDXL_PIPELINES:
        base = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )
        base.to("cuda")
        # The refiner reuses the base model's second text encoder and VAE.
        refiner = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            text_encoder_2=base.text_encoder_2,
            vae=base.vae,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        )
        refiner.to("cuda")
        # Store both only after both loaded, so a failed load is retried
        # in full on the next call.
        _SDXL_PIPELINES["base"] = base
        _SDXL_PIPELINES["refiner"] = refiner
    return _SDXL_PIPELINES["base"], _SDXL_PIPELINES["refiner"]


def generate_image(Prompt,Negative_prompt,Steps):
    """Generate one image from a text prompt with the SDXL base + refiner pair.

    The base pipeline handles the first 80% of the denoising schedule and
    hands its latents to the refiner, which finishes the remaining 20%.

    Args:
        Prompt: Text describing the desired image.
        Negative_prompt: Text describing what to avoid; passed to both stages.
        Steps: Number of inference steps. UI sliders may deliver a float or
            zero, so the value is rounded to an integer and raised to 1 if
            it is lower.

    Returns:
        The first image produced by the refiner pipeline.
    """
    n_steps = max(1, int(round(Steps)))
    base, refiner = _get_sdxl_pipelines()

    # Fraction of the schedule run by the base model; the refiner takes over
    # from this point (80/20 split).
    high_noise_frac = 0.8

    # Stage 1: base model, stopped early and returning latents.
    latents = base(
        prompt=Prompt,
        negative_prompt=Negative_prompt,
        num_inference_steps=n_steps,
        denoising_end=high_noise_frac,
        output_type="latent",
    ).images

    # Stage 2: refiner resumes from the base latents. The negative prompt is
    # forwarded here too so it applies to the whole schedule.
    return refiner(
        prompt=Prompt,
        negative_prompt=Negative_prompt,
        num_inference_steps=n_steps,
        denoising_start=high_noise_frac,
        image=latents,
    ).images[0]

# Danbooru model and its label list, filled on first use so that neither is
# fetched again on every prediction.
_DANBOORU_CACHE = {}


def _get_danbooru_model_and_labels():
    """Return the cached ``(model, labels)`` pair, loading both on first use."""
    if not _DANBOORU_CACHE:
        model = torch.hub.load('RF5/danbooru-pretrained', 'resnet50')
        model.eval()

        # Load JSON file from github as Label
        with urllib.request.urlopen("https://github.com/RF5/danbooru-pretrained/raw/master/config/class_names_6000.json") as response:
            labels = json.loads(response.read().decode())

        _DANBOORU_CACHE["model"] = model
        _DANBOORU_CACHE["labels"] = labels
    return _DANBOORU_CACHE["model"], _DANBOORU_CACHE["labels"]


def predict(input_image):
    """Return the ten most probable Danbooru tags for an image.

    Args:
        input_image: Image as an array (anything with ``astype`` that
            ``PIL.Image.fromarray`` accepts). It is converted to RGB, so
            greyscale or RGBA arrays are accepted as well.

    Returns:
        dict mapping label -> probability (Python float) for the ten
        highest-scoring labels, in descending order of probability.
    """
    model, labels = _get_danbooru_model_and_labels()

    # Convert input image from array to an RGB PIL Image
    pil_image = Image.fromarray(input_image.astype('uint8')).convert('RGB')

    # Preprocess the input image
    preprocess = transforms.Compose([
        transforms.Resize(360),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.7137, 0.6628, 0.6519], std=[0.2970, 0.3017, 0.2979]),
    ])

    # Use CUDA if available; model and batch must live on the same device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    input_batch = preprocess(pil_image).unsqueeze(0).to(device)

    # Make prediction
    with torch.no_grad():
        output = model(input_batch)

    # Sigmoid scores for the single image in the batch, as Python floats
    probs = torch.sigmoid(output[0]).cpu().tolist()

    # Pair labels with probabilities, keep the ten highest
    top_ten = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)[:10]

    return dict(top_ten)

# User-facing model name -> Hugging Face model id.
_IMAGE_CLASSIFIER_MODELS = {
    "Resnet 50": "microsoft/resnet-50",
    "Vit Base Patch16-224": "google/vit-base-patch16-224",
    "NSFW Image Detection": "Falconsai/nsfw_image_detection",
    "Vit Age Classifier": "nateraw/vit-age-classifier",
}

# Classification pipelines already built, keyed by model id, so a model is
# loaded once per process rather than on every call.
_IMAGE_CLASSIFIER_CACHE = {}


def _normalize_model_name(name):
    """Lower-case *name* and drop everything that is not a letter or digit."""
    return "".join(ch for ch in str(name).lower() if ch.isalnum())


def _resolve_classifier_model(name):
    """Map a user-facing model name to its Hugging Face model id.

    Matching ignores case, spaces and punctuation, so "resnet50" and
    "Resnet 50" select the same model. Raises KeyError for unknown names.
    """
    wanted = _normalize_model_name(name)
    for display_name, model_id in _IMAGE_CLASSIFIER_MODELS.items():
        if _normalize_model_name(display_name) == wanted:
            return model_id
    known = ", ".join(_IMAGE_CLASSIFIER_MODELS)
    raise KeyError(f"Unknown model {name!r}; expected one of: {known}")


def image_classify(input_image, model):
    """Classify an image with the selected model and report the top label.

    Args:
        input_image: A PIL image, a path/URL string, or an image array
            (arrays are converted to a PIL image first).
        model: User-facing model name, e.g. "Resnet 50".

    Returns:
        A string of the form "<score>% confident : <label>" for the
        highest-scoring class.

    Raises:
        KeyError: If *model* does not name a known classifier.
        ValueError: If no image was supplied.
    """
    # Validate the cheap inputs before loading any model weights.
    model_id = _resolve_classifier_model(model)
    if input_image is None:
        raise ValueError("No image provided")

    if not isinstance(input_image, (str, Image.Image)):
        input_image = Image.fromarray(input_image.astype('uint8'))

    classifier = _IMAGE_CLASSIFIER_CACHE.get(model_id)
    if classifier is None:
        classifier = pipeline("image-classification", model=model_id)
        _IMAGE_CLASSIFIER_CACHE[model_id] = classifier

    result = classifier(input_image)

    # Pick the prediction with the highest confidence
    best = max(result, key=lambda item: item['score'])

    # Format the score as a percentage and combine it with the label
    return f"{best['score']*100:.2f}% confident : {best['label']}"

# Define the Gradio UI (one tab per feature)
with gr.Blocks() as demo:
    # Tab 1: echoes the uploaded image back.
    with gr.Tab("Demo"):
        image_input = gr.Image(type='pil')
        image_output = gr.Image()
        demo_button = gr.Button("Generate")

    # Tab 2: text-to-image generation.
    with gr.Tab("Text2Image"):
        SD_text_input = gr.Textbox(lines=5, label="Prompt")
        SD_text2_input = gr.Textbox(lines=5, label="Negative Prompt")
        # This value is the number of inference steps, so it must be a
        # whole number of at least 1.
        Slider_input = gr.Slider(1, 100, value=40, step=1, label="Steps")
        SD_output = gr.Image()
        SD_button = gr.Button("Generate")

    # Tab 3: image classification. The choices are exactly the model names
    # that image_classify looks up.
    with gr.Tab("Image Classification"):
        option_input = gr.Dropdown(
            ["Resnet 50", "Vit Base Patch16-224", "NSFW Image Detection", "Vit Age Classifier"],
            value="Resnet 50",
            label="Model",
        )
        classify_input = gr.Image(type='pil', label="Image")
        # image_classify returns a text summary, not an image.
        classify_output = gr.Textbox(label="Prediction")
        classify_button = gr.Button("Classify")

    demo_button.click(demo_tab, inputs=image_input, outputs=image_output)
    SD_button.click(generate_image, inputs=[SD_text_input,SD_text2_input,Slider_input], outputs=SD_output)
    # Input order must follow image_classify(input_image, model).
    classify_button.click(image_classify, inputs=[classify_input, option_input], outputs=classify_output)

demo.launch(debug=True)