import requests
from flask import Flask, request, Response, redirect, jsonify, stream_with_context
import random, subprocess, string, json
import os
import logging
from logging import NullHandler
import subprocess
from functools import wraps

# Copied from pepsi's proxy (thanks <3) and modified so it works for me; it does not use that "world sim" API wrapper thingy.

def validate_token(token):
    """Return True if *token* matches the ``PASSWORD`` environment variable.

    The comparison is constant-time (``hmac.compare_digest``) so response
    timing does not reveal how much of a guess was correct.

    Args:
        token: The key supplied by the client (normally a header value).

    Returns:
        True on an exact match, False otherwise. When ``PASSWORD`` is unset,
        or *token* is not a string, every token is rejected.
    """
    import hmac  # local import keeps this block self-contained

    expected = os.environ.get('PASSWORD')
    if expected is None or not isinstance(token, str):
        # With no password configured there is nothing to match; a plain
        # `==` here would have accepted token=None.
        return False
    # compare_digest raises TypeError for non-ASCII str arguments, so compare bytes.
    return hmac.compare_digest(token.encode('utf-8'), expected.encode('utf-8'))

def requires_auth_bearer(f):
    """Decorator that only lets a request through when ``X-Api-Key`` is valid.

    A request without the header, and one whose key fails ``validate_token``,
    both get a 401 JSON error; the two cases use different error messages.
    """
    @wraps(f)
    def guarded(*args, **kwargs):
        supplied = request.headers.get('X-Api-Key')
        if supplied is None:
            return jsonify({'error': 'You need special Saya password :3c'}), 401
        if validate_token(supplied):
            return f(*args, **kwargs)
        return jsonify({'error': 'hue hue hue hue hue hue '}), 401
    return guarded

app = Flask(__name__)
# Ask jsonify() to pretty-print its output.
# NOTE(review): Flask 2.3 removed this config key (there it is
# `app.json.compact = False`); confirm the deployed Flask version honours it.
app.config.update(JSONIFY_PRETTYPRINT_REGULAR=True)

# Process-wide counters and flags rendered by the status page ('/').
# LOGGING_ENABLED is matched case-insensitively; only "true", "1" and "t" count as on.
state = dict(
    current_requests=0,      # requests currently being handled by the proxy endpoint
    total_prompts_sent=0,    # authenticated POSTs to the proxy endpoint so far
    total_tokens=0,          # token counts accumulated by the proxy endpoint
    logging_enabled=os.environ.get('LOGGING_ENABLED', 'False').lower() in ('true', '1', 't'),
)

## I'll never turn on logging; I just wanted the status page to show that it's off (same as pepsi's version :V)
def configure_logging():
    """Attach NullHandlers to the Flask app logger and the 'bloody' logger.

    The app logger also stops propagating to ancestor loggers; the 'bloody'
    logger is raised to ERROR level.
    NOTE(review): Flask may already have attached its default stderr handler
    to app.logger; adding a NullHandler does not remove it — confirm whether
    app errors are still printed.
    """
    flask_logger = app.logger
    flask_logger.addHandler(NullHandler())
    flask_logger.propagate = False

    bloody = logging.getLogger('bloody')
    bloody.setLevel(logging.ERROR)
    bloody.addHandler(NullHandler())

configure_logging()

def get_token_count(text, *, fallback=0, timeout=None):
    """Return the token count of *text* as printed by ``node tokenizer.js``.

    The text is passed as a command-line argument and the script's stdout is
    parsed as an integer.

    Args:
        text: The text to count.
        fallback: Value returned when counting fails (default 0).
        timeout: Optional seconds to wait for the node process; None waits forever.

    Returns:
        The integer count, or *fallback* if the tokenizer could not be run or
        its output is not an integer. Failures include:
        - ``node`` is missing;
        - *text* exceeds the OS per-argument size limit (E2BIG);
        - *text* contains a NUL byte;
        - the script prints nothing;
        - the timeout expires.
        Counts are only used for usage/statistics, so these must not break a request.
    """
    try:
        result = subprocess.run(
            ['node', 'tokenizer.js', text],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
        return int(result.stdout.strip())
    except (OSError, ValueError, subprocess.SubprocessError):
        # OSError: node missing / argv too long; ValueError: NUL in argv or
        # non-integer stdout; SubprocessError: TimeoutExpired.
        return fallback

def generate_message_id():
    """Return a pseudo-random message id: ``'msg_'`` plus 16 alphanumeric characters.

    Uses the non-cryptographic ``random`` module; ids are labels, not secrets.
    """
    alphabet = string.ascii_uppercase + string.digits + string.ascii_lowercase
    suffix = ''.join(random.choices(alphabet, k=16))
    return f'msg_{suffix}'


# Redirect targets for unknown paths; one is picked at random per request.
_NOT_FOUND_REDIRECTS = (
    'https://youtu.be/0o0y-MNqdQU',
    'https://youtu.be/oAJC1Pn78ZA',
    'https://youtu.be/p_5tTM9D7l0',
    'https://youtu.be/H0gEjUEneBI',
    'https://youtu.be/Hj8icpQldzc',
    'https://youtu.be/-9_sTTYXcwc',
    'https://youtu.be/LmsuxO5rfEU',
    'https://youtu.be/VJkzfV7kNYQ',
    'https://youtu.be/oCikD1xcv0o',
    'https://youtu.be/k9TSVx9gAW0',
    'https://youtu.be/Xiiy8vEWj-g',
    'https://youtu.be/FKLd1YdmIwA',
    'https://youtu.be/RJ4iaQAF6SI',
    'https://youtu.be/KPad2ftEwqc',
)


@app.errorhandler(404)
def page_not_found(e):
    """Answer every 404 with a 302 redirect to a random entry of _NOT_FOUND_REDIRECTS."""
    target = random.choice(_NOT_FOUND_REDIRECTS)
    return redirect(target), 302

def _message_text(message):
    """Return the text of one chat message, for token counting only.

    ``content`` may be a plain string or a list of content blocks. For a list,
    the ``text`` of every block whose ``text`` is a string is joined with
    spaces. Anything else (missing content, None, non-dict messages) counts
    as empty instead of raising.
    """
    if not isinstance(message, dict):
        return ''
    content = message.get('content', '')
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return ' '.join(
            block['text']
            for block in content
            if isinstance(block, dict) and isinstance(block.get('text'), str)
        )
    return ''


def _safe_token_count(text):
    """Return ``get_token_count(text)``, or 0 for empty text or a tokenizer failure.

    Counts only feed reported usage and the status page. A tokenizer failure
    must therefore not fail the proxied request or skip counter bookkeeping.
    Failures include a missing node, an argument too long for argv, or
    non-integer output.
    """
    if not text:
        return 0
    try:
        return get_token_count(text)
    except (OSError, ValueError):
        return 0


def _upstream_headers():
    """Return the headers sent with the upstream request.

    The bearer key, Origin and Referer come from the KEY, ORIGIN and REFERER
    environment variables; the other headers are fixed.
    """
    return {
        "Authorization": f"Bearer {os.environ.get('KEY')}",
        'Content-Type': 'application/json',
        'User-Agent': 'Mozilla/5.0 (iPhone; CPU iPhone OS 12_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) FxiOS/13.2b11866 Mobile/16A366 Safari/605.1.15',
        'Origin': os.environ.get('ORIGIN'),
        'Referer': os.environ.get('REFERER'),
        "Accept": "*/*",
        "Accept-Language": "en-US,en;q=0.5",
        "Sec-Fetch-Dest": "empty",
        "Sec-Fetch-Mode": "cors",
        "Sec-Fetch-Site": "same-origin",
        "Sec-GPC": "1"
    }


def _sse(event, payload):
    """Format one server-sent event whose data line is *payload* encoded as JSON."""
    return f'event: {event}\ndata: {json.dumps(payload)}\n\n'


def _iter_upstream_text(upstream):
    """Yield the text deltas (``choices[0].delta.content``) of the upstream SSE stream.

    Reads whole lines, so a multi-byte UTF-8 character is never split across
    reads. Only ``data:`` lines are considered, and ``data: [DONE]`` ends the
    stream. Lines that are blank, SSE comments (keep-alives), other SSE
    fields, or not JSON are skipped. So are events without a non-empty
    string delta (role-only, finish or usage chunks).
    """
    for raw_line in upstream.iter_lines(chunk_size=256):
        if not raw_line.startswith(b'data:'):
            continue
        payload = raw_line[len(b'data:'):].strip()
        if payload == b'[DONE]':
            break
        try:
            event = json.loads(payload)
        except ValueError:  # JSONDecodeError and UnicodeDecodeError are both ValueErrors
            continue
        choices = event.get('choices') if isinstance(event, dict) else None
        if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
            continue
        delta = choices[0].get('delta')
        text = delta.get('content') if isinstance(delta, dict) else None
        if isinstance(text, str) and text:
            yield text


def _complete_response(upstream, input_tokens):
    """Return the entire upstream body as one Anthropic-style ``message`` response.

    The raw body text becomes the single text content block, and the upstream
    HTTP status code is passed through. The output token count is reported in
    ``usage`` and added to ``state["total_tokens"]``.
    """
    try:
        # Decode the whole body at once: decoding fixed-size chunks separately
        # can split a multi-byte UTF-8 character and raise.
        output = upstream.content.decode('utf-8', 'replace')
    finally:
        upstream.close()
    output_tokens = _safe_token_count(output)
    state["total_tokens"] += output_tokens
    response_json = {
        "content": [{"text": output, "type": "text"}],
        "id": generate_message_id(),
        "model": ":^)",
        "role": "assistant",
        "stop_reason": "end_turn",
        "stop_sequence": None,
        "type": "message",
        "usage": {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
        }
    }
    return Response(json.dumps(response_json), headers={'Content-Type': 'application/json'}, status=upstream.status_code)


@stream_with_context
def _stream_response(upstream, input_tokens):
    """Re-emit the upstream SSE stream as Anthropic-style streaming events.

    Events, in order:
    - ``message_start``;
    - ``content_block_start``;
    - one ``content_block_delta`` per upstream text delta;
    - ``content_block_stop``;
    - ``message_delta``;
    - ``message_stop``.

    This generator owns the ``state["current_requests"]`` decrement and the
    ``state["total_tokens"]`` update. Both happen in ``finally``, so they also
    run when the client disconnects or the upstream fails mid-stream.
    """
    pieces = []
    output_tokens = None
    try:
        yield _sse('message_start', {
            "type": "message_start",
            "message": {
                "id": generate_message_id(),
                "type": "message",
                "role": "assistant",
                "content": [],
                "model": "claude-3-opus-20240229",
                "stop_reason": None,
                "stop_sequence": None,
                "usage": {"input_tokens": input_tokens, "output_tokens": 1},
            },
        })
        yield _sse('content_block_start', {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})
        for text in _iter_upstream_text(upstream):
            pieces.append(text)
            yield _sse('content_block_delta', {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": text}})
        yield _sse('content_block_stop', {"type": "content_block_stop", "index": 0})
        output_tokens = _safe_token_count(''.join(pieces))
        usage = {"output_tokens": output_tokens}
        # `usage` is sent at top level (Anthropic event shape) and, as before,
        # inside `delta` so clients reading the old location keep working.
        yield _sse('message_delta', {
            "type": "message_delta",
            "delta": {"stop_reason": "end_turn", "stop_sequence": None, "usage": usage},
            "usage": usage,
        })
        yield _sse('message_stop', {"type": "message_stop"})
    finally:
        upstream.close()  # release the upstream connection even if the client went away
        state["current_requests"] -= 1
        if output_tokens is None:
            # Stream ended early: still account for the text that was sent.
            output_tokens = _safe_token_count(''.join(pieces))
        state["total_tokens"] += output_tokens


@app.route('/saya/messages', methods=['POST'])
@requires_auth_bearer
def proxy_message():
    """Forward a chat request to PROXIED_URL and answer in Anthropic message format.

    Steps:
    1. Validate the JSON body (it must be an object with a ``messages`` list).
    2. Count the input tokens.
    3. Replace ``model`` with the MODEL environment variable.
    4. POST the body upstream.
    5. Respond in the matching mode. If the body's ``stream`` flag is falsy,
       return one JSON message; otherwise return an SSE stream.

    Responses:
    - 400 JSON error for a missing or malformed body;
    - 502 JSON error when the upstream request cannot be made.
    """
    state["total_prompts_sent"] += 1
    state["current_requests"] += 1
    # Once the streaming generator is handed to Flask it owns the decrement;
    # every other exit path (including exceptions) releases it here.
    handed_to_stream = False
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or not isinstance(data.get('messages'), list):
            return jsonify({'error': 'Bad Request: No messages found in the request.'}), 400

        full_text = ' '.join(_message_text(msg) for msg in data['messages'])
        input_tokens = _safe_token_count(full_text)

        data["model"] = os.environ.get('MODEL')

        try:
            upstream = requests.post(
                os.environ.get('PROXIED_URL'),
                headers=_upstream_headers(),
                data=json.dumps(data),
                stream=True,
            )
        except requests.RequestException:
            return jsonify({'error': 'Upstream request failed.'}), 502

        if not data.get('stream', False):
            return _complete_response(upstream, input_tokens)

        streamed = Response(_stream_response(upstream, input_tokens), headers={'Content-Type': 'text/event-stream'})
        handed_to_stream = True
        return streamed
    finally:
        if not handed_to_stream:
            state["current_requests"] -= 1

@app.route('/')
def index():
    """Render the HTML status page: endpoint URL and the counters in ``state``.

    The endpoint host comes from SPACE_HOST (default 'default-space-host').
    """
    host = os.environ.get('SPACE_HOST', 'default-space-host')
    rows = [
        ('Endpoint:', f"https://{host}/saya"),
        ('Prompt Logging:', str(state["logging_enabled"])),
        ('Prompters Now:', str(state["current_requests"])),
        ('Total Prompts:', str(state["total_prompts_sent"])),
        ('Total Tokens:', str(state["total_tokens"])),
        ('Password Protected?:', 'yes :3c'),
    ]
    html_lines = [
        '<body style="background-image: url(\'https://images.dragonetwork.pl/wp12317828.jpg\'); background-size: cover; background-repeat: no-repeat;">',
        '    <div style="background-color: rgba(0, 0, 0, 0.5); padding: 20px;">',
        '        <p style="color: white; font-weight: bold; font-size: 200%;">',
    ]
    html_lines.extend(
        f'            <span>{label}</span> <span>{value}</span><br>' for label, value in rows
    )
    html_lines.extend(['        </p>', '    </div>', '</body>'])
    return '\n'.join(html_lines), 200

if __name__ == '__main__':
    # Development server: listen on all interfaces, port 7860, debugger off.
    run_options = {'host': '0.0.0.0', 'port': 7860, 'debug': False}
    app.run(**run_options)