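"""Gradio demo: CLIP Interrogator + Stability API round trip.

Upload an image, reverse-engineer a text prompt for it with CLIP Interrogator
('best', 'classic', or 'fast' mode), then send that prompt to the Stability API
to generate five new images from it. Dependencies are installed at startup by
setup() below.
"""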
import subprocess, sys

def setup():
    # Install runtime dependencies: CLIP and BLIP are installed as editable
    # packages (into ./src), and the clip-interrogator repo is cloned locally.
    install_cmds = [
        ['pip', 'install', 'ftfy', 'gradio', 'regex', 'tqdm', 'stability-sdk', 'transformers==4.21.2', 'timm', 'fairscale', 'requests'],
        ['pip', 'install', '-e', 'git+https://github.com/openai/CLIP.git@main#egg=clip'],
        ['pip', 'install', '-e', 'git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip'],
        ['git', 'clone', 'https://github.com/pharmapsychotic/clip-interrogator.git']
    ]
    for cmd in install_cmds:
        print(subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode('utf-8'))

setup()

# The editable installs land in ./src and the interrogator repo is cloned next
# to this script; make all three importable.
sys.path.append('src/blip')
sys.path.append('src/clip')
sys.path.append('clip-interrogator')

# These packages only become importable after setup() has run.
import clip
import torch
import gradio as gr
from clip_interrogator import Interrogator, Config

# Build the CLIP Interrogator with its default config; model weights are
# downloaded on first use.
ci = Interrogator(Config())

import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
from stability_sdk import client
import os

# The Stability API key must be supplied via the STABILITY_KEY environment variable.
stability_api = client.StabilityInference(
    key=os.environ['STABILITY_KEY'],
    verbose=True
)

from PIL import Image
import warnings
from io import BytesIO

def inferAndRebuild(image, mode):
    """Interrogate the uploaded image to recover a text prompt, then use that
    prompt to generate new images with the Stability API."""
    image = image.convert('RGB')
    if mode == 'best':
        output = ci.interrogate(image)
    elif mode == 'classic':
        output = ci.interrogate_classic(image)
    else:
        output = ci.interrogate_fast(image)

    answers = stability_api.generate(
        prompt=output,
        seed=34567,
        steps=30,
        samples=5
    )

    # Collect generated images, warning about anything caught by the safety filter.
    imglist = []
    for resp in answers:
        for artifact in resp.artifacts:
            if artifact.finish_reason == generation.FILTER:
                warnings.warn(
                    "Your request activated the API's safety filters and could not be processed. "
                    "Please modify the prompt and try again.")
            if artifact.type == generation.ARTIFACT_IMAGE:
                img = Image.open(BytesIO(artifact.binary))
                imglist.append(img)
    return [imglist, output]


# Gradio UI: image upload and interrogation mode on the input side,
# a gallery of generated images and the recovered prompt on the output side.
inputs = [
    gr.Image(type='pil'),
    gr.Radio(['best', 'classic', 'fast'], label='Mode', value='fast')
]

outputs = [
    gr.Gallery(),
    gr.Textbox(label='Prompt')
]

io = gr.Interface(
    inferAndRebuild,
    inputs,
    outputs,
    allow_flagging='never',
)

io.launch(debug=True)
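
# Usage sketch (the file name app.py is illustrative; a valid Stability API key
# is required):
#   export STABILITY_KEY=<your key>
#   python app.py
# Gradio then serves the interface locally, by default on http://127.0.0.1:7860.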