ryuzaki-api / fluxai.py
randydev's picture
Update fluxai.py
4015bf1
raw
history blame
6.53 kB
import asyncio
import base64
import functools
import io
import os
import tempfile
import uuid

import requests
from dotenv import load_dotenv
from fastapi import APIRouter, Depends
from fastapi import HTTPException
from fastapi import UploadFile
from fastapi.responses import JSONResponse
from fastapi.responses import StreamingResponse
from huggingface_hub import InferenceClient
from PIL import Image, ImageEnhance
from pydantic import BaseModel
from pymongo import MongoClient
from RyuzakiLib import GeminiLatest

from models import *
class FluxAI(BaseModel):
    """Request body for POST /akeno/fluxai."""

    # Account whose token balance is charged for the generation.
    # NOTE(review): defaults to a fixed id, so any request that omits user_id
    # is charged to account 1191668125 -- confirm this is intended.
    user_id: int = 1191668125
    # Caller's key; the fluxai route compares it with the users'
    # stored ryuzaki_api_key values.
    api_key: str
    # Text prompt sent to the FLUX model.
    args: str
    # When True the generated image is enhanced, captioned by Gemini and
    # returned base64-encoded instead of being streamed back as raw bytes.
    auto_enhancer: bool = False
    # Selects FLUX.1-dev instead of the default FLUX.1-schnell model.
    is_flux_dev: bool = False
class MistralAI(BaseModel):
    """Request body for POST /akeno/mistralai."""

    # User message forwarded to the Mixtral chat model.
    args: str
# All endpoints in this module are registered on this router.
router = APIRouter()
# Load a local .env file before the settings below are read. The
# os.environ[...] lookups raise KeyError at import time when a variable is
# missing, so a misconfigured deployment fails fast instead of at request time.
load_dotenv()
MONGO_URL = os.environ["MONGO_URL"]
HUGGING_TOKEN = os.environ["HUGGING_TOKEN"]
GOOGLE_API_KEY = os.environ["GOOGLE_API_KEY"]
# User records (token balances under "tokens", keys under "ryuzaki_api_key")
# live in the "users" collection of the "tiktokbot" database.
client_mongo = MongoClient(MONGO_URL)
db = client_mongo["tiktokbot"]
collection = db["users"]
# Seconds to wait on the Hugging Face Inference API (passed to requests as
# its connect/read timeout) so a stalled upstream cannot hang a request forever.
_FLUX_TIMEOUT = 120


async def _query_flux(model, args):
    """POST *args* as a text-to-image prompt to a black-forest-labs FLUX model.

    Args:
        model: Model name under ``black-forest-labs/`` (e.g. ``"FLUX.1-schnell"``).
        args: The prompt, sent as ``{"inputs": args}``.

    Returns:
        The raw response body (the generated image) on HTTP 200, otherwise
        ``None`` after printing the status code.

    Raises:
        Whatever ``requests.post`` raises (connection errors, timeouts).
    """
    api_url = f"https://api-inference.huggingface.co/models/black-forest-labs/{model}"
    headers = {"Authorization": f"Bearer {HUGGING_TOKEN}"}
    payload = {"inputs": args}
    # requests is synchronous; running it in the default executor keeps the
    # event loop free to serve other requests while the image is generated.
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(
        None,
        functools.partial(
            requests.post,
            api_url,
            headers=headers,
            json=payload,
            timeout=_FLUX_TIMEOUT,
        ),
    )
    if response.status_code != 200:
        print(f"Error status {response.status_code}")
        return None
    return response.content


async def schellwithflux(args):
    """Generate an image for prompt *args* with FLUX.1-schnell.

    Returns the image bytes, or ``None`` if the API answered with a non-200 status.
    """
    return await _query_flux("FLUX.1-schnell", args)


async def devwithflux(args):
    """Generate an image for prompt *args* with FLUX.1-dev.

    Returns the image bytes, or ``None`` if the API answered with a non-200 status.
    """
    return await _query_flux("FLUX.1-dev", args)
async def mistralai_post_message(message_str):
    """Send *message_str* as one user turn to Mixtral-8x7B-Instruct and return the reply.

    The completion is streamed (at most 500 new tokens) and the streamed
    deltas are concatenated into a single string. Chunks without ``choices``
    or whose ``delta.content`` is empty/``None`` (huggingface_hub types it as
    optional) are skipped; previously ``None`` made the concatenation raise
    ``TypeError``. ``InferenceClient`` is synchronous, so the whole exchange
    runs in the default executor to avoid blocking the event loop.
    """
    def _collect():
        client = InferenceClient(
            "mistralai/Mixtral-8x7B-Instruct-v0.1",
            token=HUGGING_TOKEN
        )
        pieces = []
        for chunk in client.chat_completion(
            messages=[{"role": "user", "content": message_str}],
            max_tokens=500,
            stream=True
        ):
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content:
                pieces.append(content)
        # join is linear; repeated += on a str is quadratic in the worst case.
        return "".join(pieces)

    return await asyncio.get_running_loop().run_in_executor(None, _collect)
def get_user_tokens_gpt(user_id):
    """Return the token balance stored for *user_id*.

    The user document is looked up by ``user_id``; an unknown user, or a
    document with no ``tokens`` field, both report a balance of 0.
    """
    record = collection.find_one({"user_id": user_id})
    return record.get("tokens", 0) if record else 0
def deduct_tokens_gpt(user_id, amount):
    """Atomically subtract *amount* tokens from *user_id*'s balance.

    Args:
        user_id: The ``user_id`` of the document to charge.
        amount: Number of tokens to deduct.

    Returns:
        True if the deduction was applied; False if the user does not exist,
        has no ``tokens`` field, or has fewer than *amount* tokens. A
        non-positive *amount* deducts nothing and returns True.

    The balance check is part of the update filter, so the check and the
    decrement happen in one server-side operation. Two concurrent requests
    therefore cannot both see a sufficient balance and drive it negative,
    which a separate read followed by an update allows.
    """
    if amount <= 0:
        return True
    result = collection.update_one(
        {"user_id": user_id, "tokens": {"$gte": amount}},
        {"$inc": {"tokens": -amount}}
    )
    return result.modified_count == 1
@router.get("/akeno/gettoken")
async def get_token_with_flux(user_id: int):
    """Report the token balance of *user_id*.

    A truthy balance yields ``status="True"``. A falsy one (0, which is also
    what an unknown user reports) yields ``status="False"`` with a
    "Not enough tokens" message.
    """
    balance = get_user_tokens_gpt(user_id)
    if not balance:
        return SuccessResponse(
            status="False",
            randydev={"tokens": f"Not enough tokens. Current tokens: {balance}."},
        )
    return SuccessResponse(
        status="True",
        randydev={"tokens": f"Current tokens: {balance}."},
    )
@router.post("/akeno/mistralai", response_model=SuccessResponse, responses={422: {"model": SuccessResponse}})
async def mistralai_(payload: MistralAI):
    """Forward ``payload.args`` to Mixtral and wrap the reply in a SuccessResponse.

    On success the reply text is returned under ``randydev["message"]`` with
    ``status="True"``. Any exception, including one raised while building the
    success response, becomes a ``status="False"`` response carrying the
    exception text under ``randydev["error"]``.
    """
    try:
        reply_text = await mistralai_post_message(payload.args)
        ok_body = {"message": reply_text}
        return SuccessResponse(status="True", randydev=ok_body)
    except Exception as e:
        failure = f"An error occurred: {str(e)}"
        return SuccessResponse(status="False", randydev={"error": failure})
def get_all_api_keys():
    """Return every non-empty ``ryuzaki_api_key`` stored in the users collection.

    Only that field is fetched (a projection), rather than whole user
    documents. Keys come back in cursor order and duplicates are kept.
    Documents without the field, or with an empty value, are skipped.
    """
    cursor = collection.find({}, {"ryuzaki_api_key": 1, "_id": 0})
    return [doc["ryuzaki_api_key"] for doc in cursor if doc.get("ryuzaki_api_key")]
# Prompt sent to Gemini together with the enhanced image.
# NOTE(review): this reads like a leftover from a baking demo ("identify the
# baked good ... recipe") and is ungrammatical; confirm it is the intended
# caption prompt. It is kept byte-identical here.
_FLUX_CAPTION_PROMPT = "Accurately identify the baked good in the image and provide an appropriate and recipe consistent with your analysis."

# Cost of one /akeno/fluxai generation, in user tokens.
_FLUX_TOKEN_COST = 20


def _is_registered_api_key(api_key):
    """Return True if *api_key* is non-empty and stored as some user's ``ryuzaki_api_key``."""
    # A single lookup instead of loading every user document on each request.
    if not api_key:
        return False
    return collection.find_one({"ryuzaki_api_key": api_key}, {"_id": 1}) is not None


def _enhance_image(image_bytes):
    """Apply the fixed sharpness/contrast/color boost and return JPEG bytes (quality 95)."""
    with Image.open(io.BytesIO(image_bytes)) as source:
        # JPEG cannot store alpha or palette modes; convert() also returns a
        # loaded copy that outlives the closed source image.
        image = source.convert("RGB")
    image = ImageEnhance.Sharpness(image).enhance(1.5)
    image = ImageEnhance.Contrast(image).enhance(1.2)
    image = ImageEnhance.Color(image).enhance(1.1)
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG", quality=95)
    return buffer.getvalue()


def _caption_image(jpeg_bytes):
    """Ask Gemini to describe *jpeg_bytes* and return its response.

    ``GeminiLatest.get_response_image`` is given a file path, so the bytes are
    written to a private per-call temporary ``.jpg`` file that is removed
    afterwards. Concurrent requests therefore never overwrite each other's
    image. Assumes get_response_image has finished reading the file when it
    returns -- TODO confirm.
    """
    fd, path = tempfile.mkstemp(suffix=".jpg")
    try:
        with os.fdopen(fd, "wb") as handle:
            handle.write(jpeg_bytes)
        gemini = GeminiLatest(api_keys=GOOGLE_API_KEY)
        return gemini.get_response_image(_FLUX_CAPTION_PROMPT, path)
    finally:
        os.remove(path)


@router.post("/akeno/fluxai", response_model=SuccessResponse, responses={422: {"model": SuccessResponse}})
async def fluxai_image(payload: FluxAI):
    """Generate a FLUX image for ``payload.args`` and charge the caller.

    Steps:
    1. ``payload.api_key`` must belong to a registered user; otherwise an
       error response is returned and nothing is charged.
    2. ``_FLUX_TOKEN_COST`` (20) tokens are deducted from ``payload.user_id``.
       If the balance is too low, an error response with the current balance
       is returned.
    3. The image is generated with FLUX.1-dev (``is_flux_dev``) or
       FLUX.1-schnell.
       - Without ``auto_enhancer``, the bytes are streamed back with media
         type ``image/jpeg``.
       - With it, the image is sharpened and contrast/color boosted, then
         captioned by Gemini. Both are returned in a SuccessResponse, the
         image base64-encoded.

    Any exception during step 3 is reported as a ``status="False"``
    SuccessResponse rather than an HTTP error.
    """
    # Validate the key before charging, so a bad key never costs tokens.
    if not _is_registered_api_key(payload.api_key):
        return SuccessResponse(
            status="False",
            randydev={"error": "Error required api_key"}
        )
    if not deduct_tokens_gpt(payload.user_id, amount=_FLUX_TOKEN_COST):
        tokens = get_user_tokens_gpt(payload.user_id)
        return SuccessResponse(
            status="False",
            randydev={"error": f"Not enough tokens. Current tokens: {tokens} and required api_key."}
        )
    # NOTE(review): tokens are not refunded when generation below fails;
    # confirm whether that is intended.
    try:
        if payload.is_flux_dev:
            image_bytes = await devwithflux(payload.args)
        else:
            image_bytes = await schellwithflux(payload.args)
        if image_bytes is None:
            return SuccessResponse(
                status="False",
                randydev={"error": "Failed to generate an image"}
            )
        if not payload.auto_enhancer:
            return StreamingResponse(io.BytesIO(image_bytes), media_type="image/jpeg")
        # Pillow processing and the Gemini call are blocking; keep them off
        # the event loop.
        loop = asyncio.get_running_loop()
        enhanced_jpeg = await loop.run_in_executor(None, _enhance_image, image_bytes)
        caption = await loop.run_in_executor(None, _caption_image, enhanced_jpeg)
        encoded_string = base64.b64encode(enhanced_jpeg).decode('utf-8')
        return SuccessResponse(
            status="True",
            randydev={"image_data": encoded_string, "caption": caption}
        )
    except Exception as e:
        return SuccessResponse(
            status="False",
            randydev={"error": f"An error occurred: {str(e)}"}
        )