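# Gradio app: summarize an uploaded moodboard's design language with GPT-4o,
# then generate a matching product render with SDXL on Replicate.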
import os
import base64
import numpy as np
from PIL import Image
import io
import requests
import replicate
from flask import Flask, request
import gradio as gr
from openai import OpenAI
from dotenv import load_dotenv, find_dotenv
# Locate and load the .env file with the API credentials
dotenv_path = find_dotenv()
load_dotenv(dotenv_path)
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
REPLICATE_API_TOKEN = os.getenv('REPLICATE_API_TOKEN')
client = OpenAI()
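# The OpenAI client reads OPENAI_API_KEY from the environment automatically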
def call_openai(pil_image):
    """Describe the moodboard with GPT-4o and return its one-sentence summary."""
    # Save the PIL image to a bytes buffer
    buffered = io.BytesIO()
    pil_image.save(buffered, format="JPEG")

    # Encode the image to base64 so it can be sent inline as a data URL
    image_data = base64.b64encode(buffered.getvalue()).decode('utf-8')

    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": (
                            "You are a product designer. I've attached a moodboard here. "
                            "In one sentence, what do all of these elements have in common? "
                            "Answer from a design language perspective, as if you were telling "
                            "another designer to create something similar, including any "
                            "repeating colors and materials and shapes and textures."
                        ),
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "data:image/jpeg;base64," + image_data,
                        },
                    },
                ],
            }
        ],
        max_tokens=300,
    )
    return response.choices[0].message.content
def image_classifier(moodboard, prompt):
    """Generate a product render that matches the moodboard's design language."""
    # Convert the numpy array from Gradio to a PIL image
    pil_image = Image.fromarray(moodboard.astype('uint8'))

    # Summarize the moodboard, then drop the word 'moodboard' so it does not
    # leak into the image prompt
    openai_response = call_openai(pil_image)
    openai_response = openai_response.replace('moodboard', '')

    # Call the SDXL model on Replicate with the user prompt plus the summary
    # (the slice drops the first 20 characters of the summary, skipping its lead-in)
    sdxl_input = {
        "width": 768,
        "height": 768,
        "prompt": "high quality render of " + prompt + ", " + openai_response[20:],
        "negative_prompt": "worst quality, low quality, illustration, 2d, painting, cartoons, sketch",
        "refine": "expert_ensemble_refiner",
        "apply_watermark": False,
        "num_inference_steps": 25,
    }
    output = replicate.run(
        "stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc",
        input=sdxl_input,
    )

    # Download the generated image from the returned URL
    image_url = output[0]
    print(image_url)
    response = requests.get(image_url)
    print(response)
    img = Image.open(io.BytesIO(response.content))

    return img  # Return the PIL image object for Gradio to display
# app = Flask(__name__)
# os.environ.get("REPLICATE_API_TOKEN")
# @app.route("/")
# def index():
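# Gradio UI: takes a moodboard image and a text prompt, returns the generated render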
demo = gr.Interface(fn=image_classifier, inputs=["image", "text"], outputs="image")
demo.launch(share=True)