import io
import json
import os
import pickle
import time

import gradio as gr
import numpy as np
import redis
import torch
import torchaudio
from PIL import Image
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    FluxPipeline,
    DiffusionPipeline,
)
from transformers import (
    pipeline as transformers_pipeline,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    GPT2Tokenizer,
    GPT2LMHeadModel,
)
from audiocraft.models import musicgen
from huggingface_hub import HfFolder
from google.cloud import storage
from tqdm import tqdm

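# Runtime configuration comes from the environment: HF_TOKEN, REDIS_HOST,
# REDIS_PORT, REDIS_PASSWORD, GCS_CREDENTIALS (a service-account JSON string),
# and GCS_BUCKET_NAME.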
hf_token = os.getenv("HF_TOKEN")
redis_host = os.getenv("REDIS_HOST")
redis_port = int(os.getenv("REDIS_PORT", 6379))
redis_password = os.getenv("REDIS_PASSWORD")
gcs_credentials = json.loads(os.getenv("GCS_CREDENTIALS"))
gcs_bucket_name = os.getenv("GCS_BUCKET_NAME")

if hf_token:
    HfFolder.save_token(hf_token)

storage_client = storage.Client.from_service_account_info(gcs_credentials)


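# ---- Redis helpers: connect with retry, then pickle objects in and out ----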
def connect_to_redis():
    while True:
        try:
            redis_client = redis.Redis(
                host=redis_host, port=redis_port, password=redis_password
            )
            redis_client.ping()
            return redis_client
        except (
            redis.exceptions.ConnectionError,
            redis.exceptions.TimeoutError,
            BrokenPipeError,
        ) as e:
            print(f"Connection to Redis failed: {e}. Retrying in 1 second...")
            time.sleep(1)


def reconnect_if_needed(redis_client):
    try:
        redis_client.ping()
    except (
        redis.exceptions.ConnectionError,
        redis.exceptions.TimeoutError,
        BrokenPipeError,
    ):
        print("Reconnecting to Redis...")
        return connect_to_redis()
    return redis_client


def load_object_from_redis(key):
    redis_client = connect_to_redis()
    redis_client = reconnect_if_needed(redis_client)
    try:
        obj_data = redis_client.get(key)
        return pickle.loads(obj_data) if obj_data else None
    except (pickle.PickleError, redis.exceptions.RedisError) as e:
        print(f"Failed to load object from Redis: {e}")
        return None


def save_object_to_redis(key, obj):
    redis_client = connect_to_redis()
    redis_client = reconnect_if_needed(redis_client)
    try:
        redis_client.set(key, pickle.dumps(obj))
    except redis.exceptions.RedisError as e:
        print(f"Failed to save object to Redis: {e}")


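# ---- Google Cloud Storage helpers ----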
def upload_to_gcs(bucket_name, blob_name, data):
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(blob_name)
    blob.upload_from_string(data)


def download_from_gcs(bucket_name, blob_name):
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(blob_name)
    return blob.download_as_bytes()


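# Cache-first model loading: return the pickled model from Redis if present;
# otherwise download it, cache the pickle in Redis, and mirror it to GCS.
# Note that pickling multi-gigabyte pipelines is only practical for small models.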
def get_model_or_download(model_id, redis_key, loader_func):
    model = load_object_from_redis(redis_key)
    if model:
        return model
    try:
        with tqdm(total=1, desc=f"Downloading {model_id}") as pbar:
            model = loader_func(model_id, torch_dtype=torch.float16)
            pbar.update(1)
        save_object_to_redis(redis_key, model)
        model_bytes = pickle.dumps(model)
        upload_to_gcs(gcs_bucket_name, redis_key, model_bytes)
        return model
    except Exception as e:
        print(f"Failed to load or save model: {e}")
        return None


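# ---- Generation endpoints. Each checks the Redis cache first, then generates,
# caches the result, and mirrors a copy to the GCS bucket. ----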
def generate_image(prompt):
    redis_key = f"generated_image:{prompt}"
    image_bytes = load_object_from_redis(redis_key)
    if not image_bytes:
        try:
            with tqdm(total=1, desc="Generating image") as pbar:
                image = text_to_image_pipeline(prompt).images[0]
                pbar.update(1)
            buffered = io.BytesIO()
            image.save(buffered, format="JPEG")
            image_bytes = buffered.getvalue()
            save_object_to_redis(redis_key, image_bytes)
            upload_to_gcs(gcs_bucket_name, redis_key, image_bytes)
        except Exception as e:
            print(f"Failed to generate image: {e}")
            return None
    # gr.Image(type="pil") expects a PIL image, not raw JPEG bytes.
    return Image.open(io.BytesIO(image_bytes))


def edit_image_with_prompt(image, prompt, strength=0.75):
    # gr.Image(type="pil") hands us a PIL image; include its pixels in the
    # cache key so different source images don't collide on the same prompt.
    redis_key = f"edited_image:{hash(image.tobytes())}:{prompt}:{strength}"
    edited_image_bytes = load_object_from_redis(redis_key)
    if not edited_image_bytes:
        try:
            with tqdm(total=1, desc="Editing image") as pbar:
                edited_image = img2img_pipeline(
                    prompt=prompt, image=image, strength=strength
                ).images[0]
                pbar.update(1)
            buffered = io.BytesIO()
            edited_image.save(buffered, format="JPEG")
            edited_image_bytes = buffered.getvalue()
            save_object_to_redis(redis_key, edited_image_bytes)
            upload_to_gcs(gcs_bucket_name, redis_key, edited_image_bytes)
        except Exception as e:
            print(f"Failed to edit image: {e}")
            return None
    return Image.open(io.BytesIO(edited_image_bytes))


def generate_song(prompt, duration=10):
    redis_key = f"generated_song:{prompt}:{duration}"
    audio = load_object_from_redis(redis_key)
    if not audio:
        try:
            with tqdm(total=1, desc="Generating song") as pbar:
                music_gen.set_generation_params(duration=duration)
                wav = music_gen.generate([prompt])  # (batch, channels, samples)
                pbar.update(1)
            # gr.Audio(type="numpy") expects a (sample_rate, samples) tuple.
            audio = (music_gen.sample_rate, wav[0, 0].cpu().numpy())
            save_object_to_redis(redis_key, audio)
            upload_to_gcs(gcs_bucket_name, redis_key, pickle.dumps(audio))
        except Exception as e:
            print(f"Failed to generate song: {e}")
            return None
    return audio


def generate_text(prompt):
    redis_key = f"generated_text:{prompt}"
    text = load_object_from_redis(redis_key)
    if not text:
        try:
            with tqdm(total=1, desc="Generating text") as pbar:
                text = text_gen_pipeline(prompt, max_new_tokens=256)[0][
                    "generated_text"
                ].strip()
                pbar.update(1)
            save_object_to_redis(redis_key, text)
            upload_to_gcs(gcs_bucket_name, redis_key, text.encode())
        except Exception as e:
            print(f"Failed to generate text: {e}")
            return None
    return text


def generate_flux_image(prompt):
    redis_key = f"generated_flux_image:{prompt}"
    flux_image_bytes = load_object_from_redis(redis_key)
    if not flux_image_bytes:
        try:
            with tqdm(total=1, desc="Generating FLUX image") as pbar:
                flux_image = flux_pipeline(
                    prompt,
                    guidance_scale=0.0,
                    num_inference_steps=4,
                    max_sequence_length=256,
                    generator=torch.Generator("cpu").manual_seed(0),
                ).images[0]
                pbar.update(1)
            buffered = io.BytesIO()
            flux_image.save(buffered, format="JPEG")
            flux_image_bytes = buffered.getvalue()
            save_object_to_redis(redis_key, flux_image_bytes)
            upload_to_gcs(gcs_bucket_name, redis_key, flux_image_bytes)
        except Exception as e:
            print(f"Failed to generate flux image: {e}")
            return None
    return Image.open(io.BytesIO(flux_image_bytes))


def generate_code(prompt):
    redis_key = f"generated_code:{prompt}"
    code = load_object_from_redis(redis_key)
    if not code:
        try:
            with tqdm(total=1, desc="Generating code") as pbar:
                inputs = starcoder_tokenizer.encode(prompt, return_tensors="pt").to(
                    starcoder_model.device
                )
                outputs = starcoder_model.generate(inputs, max_new_tokens=256)
                code = starcoder_tokenizer.decode(outputs[0], skip_special_tokens=True)
                pbar.update(1)
            save_object_to_redis(redis_key, code)
            upload_to_gcs(gcs_bucket_name, redis_key, code.encode())
        except Exception as e:
            print(f"Failed to generate code: {e}")
            return None
    return code


def test_model_meta_llama():
    redis_key = "meta_llama_test_response"
    response = load_object_from_redis(redis_key)
    if not response:
        try:
            messages = [
                {
                    "role": "system",
                    "content": "You are a pirate chatbot who always responds in pirate speak!",
                },
                {"role": "user", "content": "Who are you?"},
            ]
            with tqdm(total=1, desc="Testing Meta-Llama") as pbar:
                # With chat-style input the pipeline returns the whole
                # conversation; the assistant's reply is the last message.
                response = meta_llama_pipeline(messages, max_new_tokens=256)[0][
                    "generated_text"
                ][-1]["content"].strip()
                pbar.update(1)
            save_object_to_redis(redis_key, response)
            upload_to_gcs(gcs_bucket_name, redis_key, response.encode())
        except Exception as e:
            print(f"Failed to test Meta-Llama: {e}")
            return None
    return response


def generate_image_sdxl(prompt):
    redis_key = f"generated_image_sdxl:{prompt}"
    image_bytes = load_object_from_redis(redis_key)
    if not image_bytes:
        try:
            with tqdm(total=1, desc="Generating SDXL image") as pbar:
                image = base(
                    prompt=prompt,
                    num_inference_steps=40,
                    denoising_end=0.8,
                    output_type="latent",
                ).images
                image = refiner(
                    prompt=prompt,
                    num_inference_steps=40,
                    denoising_start=0.8,
                    image=image,
                ).images[0]
                pbar.update(1)
            buffered = io.BytesIO()
            image.save(buffered, format="JPEG")
            image_bytes = buffered.getvalue()
            save_object_to_redis(redis_key, image_bytes)
            upload_to_gcs(gcs_bucket_name, redis_key, image_bytes)
        except Exception as e:
            print(f"Failed to generate SDXL image: {e}")
            return None
    return Image.open(io.BytesIO(image_bytes))


def generate_musicgen_melody(prompt):
    redis_key = f"generated_musicgen_melody:{prompt}"
    audio = load_object_from_redis(redis_key)
    if not audio:
        try:
            with tqdm(total=1, desc="Generating MusicGen melody") as pbar:
                melody, sr = torchaudio.load("./assets/bach.mp3")
                # One melody conditioning signal per prompt in the batch.
                wav = music_gen_melody.generate_with_chroma(
                    [prompt], melody[None], sr
                )
                pbar.update(1)
            audio = (music_gen_melody.sample_rate, wav[0, 0].cpu().numpy())
            save_object_to_redis(redis_key, audio)
            upload_to_gcs(gcs_bucket_name, redis_key, pickle.dumps(audio))
        except Exception as e:
            print(f"Failed to generate MusicGen melody: {e}")
            return None
    return audio


def generate_musicgen_large(prompt):
    redis_key = f"generated_musicgen_large:{prompt}"
    audio = load_object_from_redis(redis_key)
    if not audio:
        try:
            with tqdm(total=1, desc="Generating MusicGen large") as pbar:
                wav = music_gen_large.generate([prompt])
                pbar.update(1)
            audio = (music_gen_large.sample_rate, wav[0, 0].cpu().numpy())
            save_object_to_redis(redis_key, audio)
            upload_to_gcs(gcs_bucket_name, redis_key, pickle.dumps(audio))
        except Exception as e:
            print(f"Failed to generate MusicGen large: {e}")
            return None
    return audio


def transcribe_audio(audio_sample):
    # gr.Audio(type="numpy") supplies a (sample_rate, int16 samples) tuple.
    sampling_rate, samples = audio_sample
    redis_key = f"transcribed_audio:{hash(samples.tobytes())}"
    text = load_object_from_redis(redis_key)
    if not text:
        try:
            with tqdm(total=1, desc="Transcribing audio") as pbar:
                audio = samples.astype(np.float32) / 32768.0  # int16 -> [-1, 1]
                if audio.ndim > 1:
                    audio = audio.mean(axis=1)  # downmix stereo to mono
                text = whisper_pipeline(
                    {"sampling_rate": sampling_rate, "raw": audio}, batch_size=8
                )["text"]
                pbar.update(1)
            save_object_to_redis(redis_key, text)
            upload_to_gcs(gcs_bucket_name, redis_key, text.encode())
        except Exception as e:
            print(f"Failed to transcribe audio: {e}")
            return None
    return text


def generate_mistral_instruct(prompt):
    redis_key = f"generated_mistral_instruct:{prompt}"
    response = load_object_from_redis(redis_key)
    if not response:
        try:
            conversation = [{"role": "user", "content": prompt}]
            with tqdm(total=1, desc="Generating Mistral Instruct response") as pbar:
                inputs = mistral_instruct_tokenizer.apply_chat_template(
                    conversation,
                    tools=tools,
                    add_generation_prompt=True,
                    return_dict=True,
                    return_tensors="pt",
                )
                inputs = inputs.to(mistral_instruct_model.device)
                outputs = mistral_instruct_model.generate(
                    **inputs, max_new_tokens=1000
                )
                # Decode only the newly generated tokens, not the echoed prompt.
                response = mistral_instruct_tokenizer.decode(
                    outputs[0][inputs["input_ids"].shape[1]:],
                    skip_special_tokens=True,
                )
                pbar.update(1)
            save_object_to_redis(redis_key, response)
            upload_to_gcs(gcs_bucket_name, redis_key, response.encode())
        except Exception as e:
            print(f"Failed to generate Mistral Instruct response: {e}")
            return None
    return response


def generate_mistral_nemo(prompt):
    redis_key = f"generated_mistral_nemo:{prompt}"
    response = load_object_from_redis(redis_key)
    if not response:
        try:
            conversation = [{"role": "user", "content": prompt}]
            with tqdm(total=1, desc="Generating Mistral Nemo response") as pbar:
                inputs = mistral_nemo_tokenizer.apply_chat_template(
                    conversation,
                    tools=tools,
                    add_generation_prompt=True,
                    return_dict=True,
                    return_tensors="pt",
                )
                inputs = inputs.to(mistral_nemo_model.device)
                outputs = mistral_nemo_model.generate(**inputs, max_new_tokens=1000)
                # Decode only the newly generated tokens, not the echoed prompt.
                response = mistral_nemo_tokenizer.decode(
                    outputs[0][inputs["input_ids"].shape[1]:],
                    skip_special_tokens=True,
                )
                pbar.update(1)
            save_object_to_redis(redis_key, response)
            upload_to_gcs(gcs_bucket_name, redis_key, response.encode())
        except Exception as e:
            print(f"Failed to generate Mistral Nemo response: {e}")
            return None
    return response


def generate_gpt2_xl(prompt):
    redis_key = f"generated_gpt2_xl:{prompt}"
    response = load_object_from_redis(redis_key)
    if not response:
        try:
            with tqdm(total=1, desc="Generating GPT-2 XL response") as pbar:
                inputs = gpt2_xl_tokenizer(prompt, return_tensors="pt")
                outputs = gpt2_xl_model(**inputs)
                response = gpt2_xl_tokenizer.decode(
                    outputs[0][0], skip_special_tokens=True
                )
                pbar.update(1)
            save_object_to_redis(redis_key, response)
            upload_to_gcs(gcs_bucket_name, redis_key, response.encode())
        except Exception as e:
            print(f"Failed to generate GPT-2 XL response: {e}")
            return None
    return response


def answer_question_minicpm(image, question):
    # gr.Image(type="pil") hands us a PIL image; hash its pixels for the key.
    image = image.convert("RGB")
    redis_key = f"minicpm_answer:{hash(image.tobytes())}:{question}"
    answer = load_object_from_redis(redis_key)
    if not answer:
        try:
            with tqdm(total=1, desc="Answering question with MiniCPM") as pbar:
                msgs = [{"role": "user", "content": [image, question]}]
                answer = minicpm_model.chat(
                    image=None, msgs=msgs, tokenizer=minicpm_tokenizer
                )
                pbar.update(1)
            save_object_to_redis(redis_key, answer)
            upload_to_gcs(gcs_bucket_name, redis_key, answer.encode())
        except Exception as e:
            print(f"Failed to answer question with MiniCPM: {e}")
            return None
    return answer


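# ---- Model initialisation ----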
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

text_to_image_pipeline = get_model_or_download(
    "stabilityai/stable-diffusion-2",
    "text_to_image_model",
    StableDiffusionPipeline.from_pretrained,
)
img2img_pipeline = get_model_or_download(
    "CompVis/stable-diffusion-v1-4",
    "img2img_model",
    StableDiffusionImg2ImgPipeline.from_pretrained,
)
flux_pipeline = get_model_or_download(
    "black-forest-labs/FLUX.1-schnell", "flux_model", FluxPipeline.from_pretrained
)
# Move the diffusion pipelines onto the GPU when one is available.
for _pipe in (text_to_image_pipeline, img2img_pipeline, flux_pipeline):
    if _pipe is not None:
        _pipe.to(device)
text_gen_pipeline = transformers_pipeline(
    "text-generation", model="google/gemma-2-9b", tokenizer="google/gemma-2-9b"
)
music_gen = load_object_from_redis("music_gen") or musicgen.MusicGen.get_pretrained(
    "melody"
).to(device)
meta_llama_pipeline = get_model_or_download(
    "meta-llama/Meta-Llama-3.1-8B-Instruct", "meta_llama_model", transformers_pipeline
)
starcoder_model = AutoModelForCausalLM.from_pretrained(
    "bigcode/starcoder"
).to(device)
starcoder_tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder")

base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to(device)
refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
).to(device)
music_gen_melody = musicgen.MusicGen.get_pretrained("melody").to(device)
music_gen_melody.set_generation_params(duration=8)
music_gen_large = musicgen.MusicGen.get_pretrained("large").to(device)
music_gen_large.set_generation_params(duration=8)
whisper_pipeline = transformers_pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    chunk_length_s=30,
    device=device,
)
mistral_instruct_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-Large-Instruct-2407",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
mistral_instruct_tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-Large-Instruct-2407"
)
mistral_nemo_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-Nemo-Instruct-2407",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
mistral_nemo_tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-Nemo-Instruct-2407"
)
gpt2_xl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-xl")
gpt2_xl_model = GPT2Model.from_pretrained("gpt2-xl")
# Use .to(device) rather than a hard-coded .cuda() so CPU-only hosts still work.
minicpm_model = AutoModel.from_pretrained(
    "openbmb/MiniCPM-V-2_6",
    trust_remote_code=True,
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
).eval().to(device)
minicpm_tokenizer = AutoTokenizer.from_pretrained(
    "openbmb/MiniCPM-V-2_6", trust_remote_code=True
)

tools = []  # Define any tools needed for Mistral models

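# ---- Gradio UI: one tab per capability ----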
gen_image_tab = gr.Interface(
    fn=generate_image,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Image(type="pil"),
    title="Generate Image",
)
edit_image_tab = gr.Interface(
    fn=edit_image_with_prompt,
    inputs=[
        gr.Image(type="pil", label="Image:"),
        gr.Textbox(label="Prompt:"),
        gr.Slider(0.1, 1.0, 0.75, step=0.05, label="Strength:"),
    ],
    outputs=gr.Image(type="pil"),
    title="Edit Image",
)
generate_song_tab = gr.Interface(
    fn=generate_song,
    inputs=[
        gr.Textbox(label="Prompt:"),
        gr.Slider(5, 60, 10, step=1, label="Duration (s):"),
    ],
    outputs=gr.Audio(type="numpy"),
    title="Generate Songs",
)
generate_text_tab = gr.Interface(
    fn=generate_text,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Textbox(label="Generated Text:"),
    title="Generate Text",
)
generate_flux_image_tab = gr.Interface(
    fn=generate_flux_image,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Image(type="pil"),
    title="Generate FLUX Images",
)
generate_code_tab = gr.Interface(
    fn=generate_code,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Textbox(label="Generated Code:"),
    title="Generate Code",
)
model_meta_llama_test_tab = gr.Interface(
    fn=test_model_meta_llama,
    inputs=None,
    outputs=gr.Textbox(label="Model Output:"),
    title="Test Meta-Llama",
)
generate_image_sdxl_tab = gr.Interface(
    fn=generate_image_sdxl,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Image(type="pil"),
    title="Generate SDXL Image",
)
generate_musicgen_melody_tab = gr.Interface(
    fn=generate_musicgen_melody,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Audio(type="numpy"),
    title="Generate MusicGen Melody",
)
generate_musicgen_large_tab = gr.Interface(
    fn=generate_musicgen_large,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Audio(type="numpy"),
    title="Generate MusicGen Large",
)
transcribe_audio_tab = gr.Interface(
    fn=transcribe_audio,
    inputs=gr.Audio(type="numpy", label="Audio Sample:"),
    outputs=gr.Textbox(label="Transcribed Text:"),
    title="Transcribe Audio",
)
generate_mistral_instruct_tab = gr.Interface(
    fn=generate_mistral_instruct,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Textbox(label="Mistral Instruct Response:"),
    title="Generate Mistral Instruct Response",
)
generate_mistral_nemo_tab = gr.Interface(
    fn=generate_mistral_nemo,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Textbox(label="Mistral Nemo Response:"),
    title="Generate Mistral Nemo Response",
)
generate_gpt2_xl_tab = gr.Interface(
    fn=generate_gpt2_xl,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Textbox(label="GPT-2 XL Response:"),
    title="Generate GPT-2 XL Response",
)
answer_question_minicpm_tab = gr.Interface(
    fn=answer_question_minicpm,
    inputs=[
        gr.Image(type="pil", label="Image:"),
        gr.Textbox(label="Question:"),
    ],
    outputs=gr.Textbox(label="MiniCPM Answer:"),
    title="Answer Question with MiniCPM",
)

app = gr.TabbedInterface(
    [
        gen_image_tab,
        edit_image_tab,
        generate_song_tab,
        generate_text_tab,
        generate_flux_image_tab,
        generate_code_tab,
        model_meta_llama_test_tab,
        generate_image_sdxl_tab,
        generate_musicgen_melody_tab,
        generate_musicgen_large_tab,
        transcribe_audio_tab,
        generate_mistral_instruct_tab,
        generate_mistral_nemo_tab,
        generate_gpt2_xl_tab,
        answer_question_minicpm_tab,
    ],
    [
        "Generate Image",
        "Edit Image",
        "Generate Song",
        "Generate Text",
        "Generate FLUX Image",
        "Generate Code",
        "Test Meta-Llama",
        "Generate SDXL Image",
        "Generate MusicGen Melody",
        "Generate MusicGen Large",
        "Transcribe Audio",
        "Generate Mistral Instruct Response",
        "Generate Mistral Nemo Response",
        "Generate GPT-2 XL Response",
        "Answer Question with MiniCPM",
    ],
)

if __name__ == "__main__":
    app.launch(share=True)