"""Gradio demo: describe an image as a text prompt, then regenerate it.

At import time this script installs its own dependencies, loads a CLIP
Interrogator, creates a Stability inference client (API key read from the
``STABILITY_KEY`` environment variable) and launches a Gradio interface
around :func:`inferAndRebuild`.
"""
import argparse
import os
import random
import subprocess
import sys
import time
import warnings
from io import BytesIO


def setup():
    """Install the runtime dependencies and clone clip-interrogator.

    Each command is run in order and its captured stdout is printed. A
    command that exits with a non-zero status is reported on stderr but does
    not abort the script: the ``git clone`` step, for example, fails whenever
    the target directory already exists from a previous run.
    """
    # Invoke pip through the running interpreter so packages are installed
    # into the same environment that performs the imports below.
    pip = [sys.executable, '-m', 'pip']
    install_cmds = [
        pip + ['install', 'ftfy', 'gradio', 'regex', 'tqdm', 'stability-sdk',
               'transformers==4.21.2', 'timm', 'fairscale', 'requests'],
        pip + ['install', '-e',
               'git+https://github.com/openai/CLIP.git@main#egg=clip'],
        pip + ['install', '-e',
               'git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip'],
        ['git', 'clone',
         'https://github.com/pharmapsychotic/clip-interrogator.git'],
    ]
    for cmd in install_cmds:
        result = subprocess.run(cmd, stdout=subprocess.PIPE)
        # errors='replace': tool output is not guaranteed to be valid UTF-8.
        print(result.stdout.decode('utf-8', errors='replace'))
        if result.returncode != 0:
            print(
                f"setup: command exited with status {result.returncode}: "
                f"{' '.join(cmd)}",
                file=sys.stderr,
            )


setup()

# Make the freshly installed/cloned sources importable.
# NOTE(review): presumably 'src/blip' and 'src/clip' are where pip places the
# editable VCS checkouts relative to the working directory - confirm.
sys.path.append('src/blip')
sys.path.append('src/clip')
sys.path.append('clip-interrogator')

# These imports must stay below setup() and the sys.path changes above.
import clip
import torch
import gradio as gr
from PIL import Image
from clip_interrogator import Interrogator, Config
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
from stability_sdk import client

# Shared interrogator instance used by every request.
ci = Interrogator(Config())

# Shared Stability client; raises KeyError at startup if STABILITY_KEY is
# not set in the environment.
stability_api = client.StabilityInference(
    key=os.environ['STABILITY_KEY'],
    verbose=True,
)


def inferAndRebuild(image, mode):
    """Turn ``image`` into a text prompt and generate images from that prompt.

    Args:
        image: The input image. It must provide ``convert('RGB')``; the
            Gradio input below is configured with ``type='pil'``.
        mode: Interrogation mode. ``'best'`` calls ``ci.interrogate``,
            ``'classic'`` calls ``ci.interrogate_classic``, and any other
            value calls ``ci.interrogate_fast``.

    Returns:
        A two-element list ``[images, prompt]``: the images decoded from the
        generation response, and the interrogator output used as the prompt.
    """
    image = image.convert('RGB')

    if mode == 'best':
        output = ci.interrogate(image)
    elif mode == 'classic':
        output = ci.interrogate_classic(image)
    else:
        output = ci.interrogate_fast(image)

    answers = stability_api.generate(
        prompt=str(output),
        seed=34567,
        steps=30,
        samples=5,
    )

    imglist = []
    for resp in answers:
        for artifact in resp.artifacts:
            if artifact.finish_reason == generation.FILTER:
                warnings.warn(
                    "Your request activated the API's safety filters and "
                    "could not be processed. "
                    "Please modify the prompt and try again.")
            # Deliberately not an elif: image artifacts are collected
            # regardless of the finish reason checked above.
            if artifact.type == generation.ARTIFACT_IMAGE:
                imglist.append(Image.open(BytesIO(artifact.binary)))

    return [imglist, output]


# Input components, in the order of inferAndRebuild's parameters.
inputs = [
    gr.inputs.Image(type='pil'),
    gr.Radio(['best', 'classic', 'fast'], label='Models', value='fast'),
]

# Output components, in the order of inferAndRebuild's return list.
outputs = [
    gr.Gallery(),
    gr.outputs.Textbox(label='Prompt'),
]

io = gr.Interface(
    inferAndRebuild,
    inputs,
    outputs,
    allow_flagging=False,
)

io.launch(debug=True)