Spaces:
Running
Running
from fastapi import APIRouter, Depends | |
from fastapi.responses import StreamingResponse | |
from PIL import Image, ImageEnhance | |
from fastapi import HTTPException | |
import io | |
import requests | |
import os | |
import base64 | |
from dotenv import load_dotenv | |
from pydantic import BaseModel | |
from pymongo import MongoClient | |
from models import * | |
from huggingface_hub import InferenceClient | |
from fastapi import UploadFile | |
from fastapi.responses import JSONResponse | |
import uuid | |
from RyuzakiLib import GeminiLatest | |
class FluxAI(BaseModel):
    # Request body for the FLUX image-generation endpoint (fluxai_image).
    # NOTE(review): user_id defaults to one fixed account, so a request that
    # omits user_id has its tokens deducted from that account — confirm this
    # is intended.
    user_id: int = 1191668125
    # Checked against the "ryuzaki_api_key" values stored in the users
    # collection before an image is returned.
    api_key: str
    # Text prompt, sent to the model as the "inputs" field.
    args: str
    # When True, the generated image is sharpened / contrast- and
    # colour-boosted and returned base64-encoded with a Gemini caption
    # instead of as a raw image stream.
    auto_enhancer: bool = False
    # When True, FLUX.1-dev is used; otherwise FLUX.1-schnell.
    is_flux_dev: bool = False
class MistralAI(BaseModel):
    # Request body for the Mixtral chat endpoint (mistralai_).
    # `args` is sent to the model as the single user message.
    args: str
# Module wiring: the router for this module's endpoints, secrets read from
# the environment (load_dotenv lets a local .env file populate it), and the
# MongoDB collection holding per-user token balances and API keys.
router = APIRouter()

load_dotenv()
# Looked up in this order; a missing variable raises KeyError at import time.
MONGO_URL, HUGGING_TOKEN, GOOGLE_API_KEY = (
    os.environ[_var] for _var in ("MONGO_URL", "HUGGING_TOKEN", "GOOGLE_API_KEY")
)

client_mongo = MongoClient(MONGO_URL)
db = client_mongo["tiktokbot"]
collection = db["users"]
# Seconds to wait for the Hugging Face inference API before giving up.
# Without a timeout, `requests` can wait forever on a stalled connection.
_FLUX_TIMEOUT = 120


def _flux_request_sync(model_name, args, timeout):
    """POST ``args`` to a black-forest-labs FLUX model and return the image bytes.

    Returns ``None`` (after printing the status code) when the API answers
    with anything other than HTTP 200. Network errors and timeouts from
    ``requests`` propagate to the caller.
    """
    api_url = (
        "https://api-inference.huggingface.co/models/black-forest-labs/" + model_name
    )
    headers = {"Authorization": f"Bearer {HUGGING_TOKEN}"}
    payload = {"inputs": args}
    response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
    if response.status_code != 200:
        print(f"Error status {response.status_code}")
        return None
    return response.content


async def _flux_request(model_name, args, timeout):
    """Run the blocking FLUX request in the default executor.

    ``requests`` is synchronous; calling it directly inside a coroutine would
    block the whole event loop for the duration of image generation.
    """
    import asyncio
    import functools

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        None, functools.partial(_flux_request_sync, model_name, args, timeout)
    )


async def schellwithflux(args, timeout=_FLUX_TIMEOUT):
    """Generate an image with FLUX.1-schnell.

    Args:
        args: Text prompt.
        timeout: Seconds to wait for the inference API.

    Returns:
        The raw image bytes, or ``None`` if the API returned a non-200 status.
    """
    return await _flux_request("FLUX.1-schnell", args, timeout)


async def devwithflux(args, timeout=_FLUX_TIMEOUT):
    """Generate an image with FLUX.1-dev.

    Args:
        args: Text prompt.
        timeout: Seconds to wait for the inference API.

    Returns:
        The raw image bytes, or ``None`` if the API returned a non-200 status.
    """
    return await _flux_request("FLUX.1-dev", args, timeout)
async def mistralai_post_message(message_str):
    """Send one user message to Mixtral-8x7B-Instruct and return the full reply text.

    The streamed completion (max 500 tokens) is consumed in a worker thread
    because ``InferenceClient`` is synchronous and would otherwise block the
    event loop. Stream chunks that carry no text are skipped rather than
    concatenated; their ``choices`` list may be empty or their
    ``delta.content`` may be ``None``. Exceptions raised by the client
    propagate to the caller.
    """
    import asyncio

    def _collect_reply():
        client = InferenceClient(
            "mistralai/Mixtral-8x7B-Instruct-v0.1",
            token=HUGGING_TOKEN,
        )
        pieces = []
        for chunk in client.chat_completion(
            messages=[{"role": "user", "content": message_str}],
            max_tokens=500,
            stream=True,
        ):
            if not chunk.choices:
                continue
            text = chunk.choices[0].delta.content
            if text:
                pieces.append(text)
        return "".join(pieces)

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _collect_reply)
def get_user_tokens_gpt(user_id):
    """Return the token balance recorded for ``user_id``.

    Falls back to 0 when there is no matching user document, or when the
    document has no ``"tokens"`` field.
    """
    record = collection.find_one({"user_id": user_id})
    return record.get("tokens", 0) if record else 0
def deduct_tokens_gpt(user_id, amount):
    """Atomically subtract ``amount`` tokens from ``user_id``'s balance.

    The balance check and the decrement are a single conditional
    ``update_one``. A separate read followed by a write would let two
    concurrent requests both pass the check and drive the balance negative.

    Args:
        user_id: Value of the ``user_id`` field identifying the user document.
        amount: Number of tokens to deduct.

    Returns:
        True if a user document with at least ``amount`` tokens was found and
        debited. False otherwise, including when no such user document exists.
    """
    result = collection.update_one(
        {"user_id": user_id, "tokens": {"$gte": amount}},
        {"$inc": {"tokens": -amount}},
    )
    # matched_count (not modified_count) so amount == 0 still counts as success.
    return result.matched_count > 0
async def get_token_with_flux(user_id: int):
    """Report ``user_id``'s token balance wrapped in a SuccessResponse.

    A truthy (non-zero) balance yields status "True". A zero balance yields
    status "False" with a "Not enough tokens" message.
    """
    balance = get_user_tokens_gpt(user_id)
    if not balance:
        return SuccessResponse(
            status="False",
            randydev={"tokens": f"Not enough tokens. Current tokens: {balance}."},
        )
    return SuccessResponse(
        status="True",
        randydev={"tokens": f"Current tokens: {balance}."},
    )
async def mistralai_(payload: MistralAI):
    """Send ``payload.args`` to Mixtral and wrap the reply in a SuccessResponse.

    Any exception is reported as status "False" with its message, rather
    than being raised to the caller.
    """
    try:
        reply = await mistralai_post_message(payload.args)
        return SuccessResponse(status="True", randydev={"message": reply})
    except Exception as exc:
        return SuccessResponse(
            status="False",
            randydev={"error": f"An error occurred: {exc!s}"},
        )
def get_all_api_keys():
    """Return every truthy ``ryuzaki_api_key`` stored in the users collection.

    Only the ``ryuzaki_api_key`` field is projected, so MongoDB does not
    transfer entire user documents just to read one field. Keys are returned
    in cursor order; documents without a key, or with an empty one, are
    skipped.
    """
    cursor = collection.find({}, {"ryuzaki_api_key": 1})
    return [doc["ryuzaki_api_key"] for doc in cursor if doc.get("ryuzaki_api_key")]
def _flux_is_known_api_key(api_key):
    """Return True if ``api_key`` is non-empty and stored as some user's ryuzaki_api_key."""
    if not api_key:
        # Mirrors get_all_api_keys(), which ignores empty keys.
        return False
    return collection.find_one({"ryuzaki_api_key": api_key}, {"_id": 1}) is not None


def _flux_enhance_image(image_bytes):
    """Decode ``image_bytes`` and apply sharpness 1.5, contrast 1.2 and colour 1.1.

    Images in modes JPEG cannot store (e.g. RGBA, P) are converted to RGB
    first. Sharpening a palette image, or saving RGBA as JPEG, would
    otherwise raise.
    """
    with Image.open(io.BytesIO(image_bytes)) as source:
        image = source if source.mode in ("RGB", "L") else source.convert("RGB")
        image = ImageEnhance.Sharpness(image).enhance(1.5)
        image = ImageEnhance.Contrast(image).enhance(1.2)
        image = ImageEnhance.Color(image).enhance(1.1)
    return image


def _flux_encode_and_caption(image):
    """Encode ``image`` as base64 JPEG and caption it with Gemini.

    GeminiLatest.get_response_image is given a file path. A unique temporary
    file is used and removed afterwards, so concurrent requests cannot
    overwrite each other's image and no files accumulate in the working
    directory.

    Returns:
        A ``(base64_jpeg_str, caption)`` tuple.
    """
    import contextlib
    import tempfile

    buffer = io.BytesIO()
    image.save(buffer, format="JPEG", quality=95)
    jpeg_bytes = buffer.getvalue()
    encoded_string = base64.b64encode(jpeg_bytes).decode("utf-8")

    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        tmp.write(jpeg_bytes)
        image_path = tmp.name
    try:
        prompt = "Accurately identify the baked good in the image and provide an appropriate and recipe consistent with your analysis."
        caption = GeminiLatest(api_keys=GOOGLE_API_KEY).get_response_image(prompt, image_path)
    finally:
        with contextlib.suppress(OSError):
            os.remove(image_path)
    return encoded_string, caption


async def fluxai_image(payload: FluxAI):
    """Generate an image from ``payload.args`` with FLUX, charging 20 tokens.

    Flow:
      1. Reject, without charging, unless ``payload.api_key`` belongs to a user.
      2. Atomically deduct 20 tokens from ``payload.user_id``; reject if the
         balance is insufficient.
      3. Generate with FLUX.1-dev (``payload.is_flux_dev``) or FLUX.1-schnell.
      4. Return the raw bytes as a StreamingResponse. If
         ``payload.auto_enhancer`` is set, instead return a SuccessResponse
         with the enhanced JPEG (base64) and a Gemini caption.

    Failures after the token deduction are reported as
    ``SuccessResponse(status="False", ...)`` rather than raised.
    """
    cost = 20
    if not _flux_is_known_api_key(payload.api_key):
        return SuccessResponse(
            status="False",
            randydev={"error": "Error required api_key"}
        )
    if not deduct_tokens_gpt(payload.user_id, amount=cost):
        tokens = get_user_tokens_gpt(payload.user_id)
        return SuccessResponse(
            status="False",
            randydev={"error": f"Not enough tokens. Current tokens: {tokens} and required api_key."}
        )
    # NOTE(review): tokens are not refunded when generation below fails —
    # confirm this is the intended billing policy.
    try:
        if payload.is_flux_dev:
            image_bytes = await devwithflux(payload.args)
        else:
            image_bytes = await schellwithflux(payload.args)
        if image_bytes is None:
            return SuccessResponse(
                status="False",
                randydev={"error": "Failed to generate an image"}
            )
        if not payload.auto_enhancer:
            return StreamingResponse(io.BytesIO(image_bytes), media_type="image/jpeg")
        enhanced = _flux_enhance_image(image_bytes)
        encoded_string, caption = _flux_encode_and_caption(enhanced)
        return SuccessResponse(
            status="True",
            randydev={"image_data": encoded_string, "caption": caption}
        )
    except Exception as e:
        return SuccessResponse(
            status="False",
            randydev={"error": f"An error occurred: {str(e)}"}
        )