Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| """ | |
| use clip to interrogate image(s) | |
| """ | |
| import io | |
| import base64 | |
| import sys | |
| import os | |
| import asyncio | |
| import filetype | |
| from PIL import Image | |
| from util import log, Map | |
| import sdapi | |
| stats = { 'captions': {}, 'keywords': {} } | |
| exclude = ['a', 'in', 'on', 'out', 'at', 'the', 'and', 'with', 'next', 'to', 'it', 'for', 'of', 'into', 'that'] | |
| def decode(encoding): | |
| if encoding.startswith("data:image/"): | |
| encoding = encoding.split(";")[1].split(",")[1] | |
| return Image.open(io.BytesIO(base64.b64decode(encoding))) | |
| def encode(f): | |
| image = Image.open(f) | |
| exif = image.getexif() | |
| if image.mode == 'RGBA': | |
| image = image.convert('RGB') | |
| with io.BytesIO() as stream: | |
| image.save(stream, 'JPEG', exif = exif) | |
| values = stream.getvalue() | |
| encoded = base64.b64encode(values).decode() | |
| return encoded | |
| def print_summary(): | |
| captions = dict(sorted(stats['captions'].items(), key=lambda x:x[1], reverse=True)) | |
| log.info({ 'caption stats': captions }) | |
| keywords = dict(sorted(stats['keywords'].items(), key=lambda x:x[1], reverse=True)) | |
| log.info({ 'keyword stats': keywords }) | |
| async def interrogate(f): | |
| if not filetype.is_image(f): | |
| log.info({ 'interrogate skip': f }) | |
| return | |
| json = Map({ 'image': encode(f) }) | |
| log.info({ 'interrogate': f }) | |
| # run clip | |
| json.model = 'clip' | |
| res = await sdapi.post('/sdapi/v1/interrogate', json) | |
| caption = "" | |
| style = "" | |
| if 'caption' in res: | |
| caption = res.caption | |
| log.info({ 'interrogate caption': caption }) | |
| if ', by' in caption: | |
| style = caption.split(', by')[1].strip() | |
| log.info({ 'interrogate style': style }) | |
| for word in caption.split(' '): | |
| if word not in exclude: | |
| stats['captions'][word] = stats['captions'][word] + 1 if word in stats['captions'] else 1 | |
| else: | |
| log.error({ 'interrogate clip error': res }) | |
| # run booru | |
| json.model = 'deepdanbooru' | |
| res = await sdapi.post('/sdapi/v1/interrogate', json) | |
| keywords = {} | |
| if 'caption' in res: | |
| for term in res.caption.split(', '): | |
| term = term.replace('(', '').replace(')', '').replace('\\', '').split(':') | |
| if len(term) < 2: | |
| continue | |
| keywords[term[0]] = term[1] | |
| keywords = dict(sorted(keywords.items(), key=lambda x:x[1], reverse=True)) | |
| for word in keywords.items(): | |
| stats['keywords'][word[0]] = stats['keywords'][word[0]] + 1 if word[0] in stats['keywords'] else 1 | |
| log.info({ 'interrogate keywords': keywords }) | |
| else: | |
| log.error({ 'interrogate booru error': res }) | |
| return caption, keywords, style | |
| async def main(): | |
| sys.argv.pop(0) | |
| await sdapi.session() | |
| if len(sys.argv) == 0: | |
| log.error({ 'interrogate': 'no files specified' }) | |
| for arg in sys.argv: | |
| if os.path.exists(arg): | |
| if os.path.isfile(arg): | |
| await interrogate(arg) | |
| elif os.path.isdir(arg): | |
| for root, _dirs, files in os.walk(arg): | |
| for f in files: | |
| _caption, _keywords, _style = await interrogate(os.path.join(root, f)) | |
| else: | |
| log.error({ 'interrogate unknown file type': arg }) | |
| else: | |
| log.error({ 'interrogate file missing': arg }) | |
| await sdapi.close() | |
| print_summary() | |
| if __name__ == "__main__": | |
| asyncio.run(main()) | |