import base64
import os
import uuid
from io import BytesIO

import requests
from dotenv import load_dotenv
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from huggingface_hub import InferenceClient
from PIL import Image, ImageEnhance
from pydantic import BaseModel
from pymongo import MongoClient
from RyuzakiLib import GeminiLatest

from models import *


class FluxAI(BaseModel):
    user_id: int
    args: str
    auto_enhancer: bool = False


class MistralAI(BaseModel):
    args: str


router = APIRouter()

load_dotenv()
MONGO_URL = os.environ["MONGO_URL"]
HUGGING_TOKEN = os.environ["HUGGING_TOKEN"]
GOOGLE_API_KEY = os.environ["GOOGLE_API_KEY"]

client_mongo = MongoClient(MONGO_URL)
db = client_mongo["tiktokbot"]
collection = db["users"]


async def schellwithflux(args):
    """Generate an image with FLUX.1-schnell via the Hugging Face Inference API."""
    API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
    headers = {"Authorization": f"Bearer {HUGGING_TOKEN}"}
    payload = {"inputs": args}
    # Note: requests.post is blocking; an async HTTP client (e.g. httpx.AsyncClient)
    # would avoid stalling the event loop under load.
    response = requests.post(API_URL, headers=headers, json=payload)
    if response.status_code != 200:
        print(f"Error status {response.status_code}")
        return None
    return response.content


async def mistralai_post_message(message_str):
    """Stream a chat completion from Mixtral-8x7B and return the accumulated text."""
    client = InferenceClient(
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        token=HUGGING_TOKEN,
    )
    output = ""
    for message in client.chat_completion(
        messages=[{"role": "user", "content": message_str}],
        max_tokens=500,
        stream=True,
    ):
        # The final streamed chunk may carry no content, so guard against None.
        output += message.choices[0].delta.content or ""
    return output


def get_user_tokens_gpt(user_id):
    user = collection.find_one({"user_id": user_id})
    if not user:
        return 0
    return user.get("tokens", 0)


def deduct_tokens_gpt(user_id, amount):
    """Atomically deduct tokens; returns True only if the balance was sufficient.

    The balance check is folded into the update filter so that two concurrent
    requests cannot both pass a separate read-then-write check.
    """
    result = collection.update_one(
        {"user_id": user_id, "tokens": {"$gte": amount}},
        {"$inc": {"tokens": -amount}},
    )
    return result.modified_count > 0


@router.post(
    "/akeno/mistralai",
    response_model=SuccessResponse,
    responses={422: {"model": SuccessResponse}},
)
async def mistralai_(payload: MistralAI):
    try:
        response = await mistralai_post_message(payload.args)
        return SuccessResponse(
            status="True",
            randydev={"message": response},
        )
    except Exception as e:
        return SuccessResponse(
            status="False",
            randydev={"error": f"An error occurred: {str(e)}"},
        )


@router.post(
    "/akeno/fluxai",
    response_model=SuccessResponse,
    responses={422: {"model": SuccessResponse}},
)
async def fluxai_image(payload: FluxAI):
    if not deduct_tokens_gpt(payload.user_id, amount=20):
        tokens = get_user_tokens_gpt(payload.user_id)
        return SuccessResponse(
            status="False",
            randydev={"error": f"Not enough tokens. Current tokens: {tokens}. Please support @xtdevs"},
        )
    try:
        image_bytes = await schellwithflux(payload.args)
        if image_bytes is None:
            return SuccessResponse(
                status="False",
                randydev={"error": "Failed to generate an image"},
            )
        if payload.auto_enhancer:
            with Image.open(BytesIO(image_bytes)) as image:
                # Sharpen slightly, then boost contrast and color saturation.
                image = ImageEnhance.Sharpness(image).enhance(1.5)
                image = ImageEnhance.Contrast(image).enhance(1.2)
                image = ImageEnhance.Color(image).enhance(1.1)
                # Unique filename so concurrent requests don't overwrite each other.
                enhanced_image_path = f"akeno_{uuid.uuid4().hex}.jpg"
                image.save(enhanced_image_path, format="JPEG", quality=95)
            with open(enhanced_image_path, "rb") as image_file:
                # Decode to str so the value is JSON-serializable in the response.
                encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
            prompt = "Explain what this picture looks like."
            gemini = GeminiLatest(api_keys=GOOGLE_API_KEY)
            caption = gemini.get_response_image(prompt, enhanced_image_path)
            os.remove(enhanced_image_path)  # clean up the temporary file
            return SuccessResponse(
                status="True",
                randydev={"image_data": encoded_string, "caption": caption},
            )
        return StreamingResponse(BytesIO(image_bytes), media_type="image/jpeg")
    except Exception as e:
        return SuccessResponse(
            status="False",
            randydev={"error": f"An error occurred: {str(e)}"},
        )
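

# --- Usage sketch ---
# A minimal client-side example, assuming this router is included in a FastAPI
# app served at http://localhost:8000 (the base URL and user_id below are
# hypothetical, and the fluxai call requires a user document with enough tokens):
#
#   import requests
#
#   r = requests.post(
#       "http://localhost:8000/akeno/mistralai",
#       json={"args": "Write a haiku about the sea."},
#   )
#   print(r.json())
#
#   r = requests.post(
#       "http://localhost:8000/akeno/fluxai",
#       json={"user_id": 12345, "args": "a red fox in the snow", "auto_enhancer": True},
#   )
#   print(r.json())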