PokeGen / modules /inference.py
Ron Au
feat(prompts): Use original Russian labels
ff8ee9b
raw
history blame
2.64 kB
from time import gmtime, strftime
print(f'{strftime("%Y-%m-%d %H:%M:%S", gmtime())} Preparing for inference...') # noqa
from rudalle.pipelines import generate_images
from rudalle import get_rudalle_model, get_tokenizer, get_vae
from huggingface_hub import hf_hub_url, cached_download
import torch
from io import BytesIO
import base64
print(f"GPUs available: {torch.cuda.device_count()}")
print(f"GPU[0] memory: {int(torch.cuda.get_device_properties(0).total_memory / 1048576)}Mib")
print(f"GPU[0] memory reserved: {int(torch.cuda.memory_reserved(0) / 1048576)}Mib")
print(f"GPU[0] memory allocated: {int(torch.cuda.memory_allocated(0) / 1048576)}Mib")
device = "cuda" if torch.cuda.is_available() else "cpu"
fp16 = torch.cuda.is_available()
file_dir = "./models"
file_name = "pytorch_model.bin"
config_file_url = hf_hub_url(repo_id="minimaxir/ai-generated-pokemon-rudalle", filename=file_name)
cached_download(config_file_url, cache_dir=file_dir, force_filename=file_name)
model = get_rudalle_model('Malevich', pretrained=False, fp16=fp16, device=device)
model.load_state_dict(torch.load(f"{file_dir}/{file_name}", map_location=f"{'cuda:0' if torch.cuda.is_available() else 'cpu'}"))
vae = get_vae().to(device)
tokenizer = get_tokenizer()
print(f'{strftime("%Y-%m-%d %H:%M:%S", gmtime())} Ready for inference')
def english_to_russian(english_string):
word_map = {
"grass": "Покемон трава",
"fire": "Покемон огня",
"water": "Покемон в воду",
"lightning": "Покемон электрического типа",
"fighting": "Покемон боевого типа",
"psychic": "Покемон психического типа",
"colorless": "Покемон нормального типа",
"darkness": "Покемон темного типа",
"metal": "Покемон из стали типа",
"dragon": "Покемон типа дракона",
"fairy": "Покемон фея"
}
return word_map[english_string.lower()]
def generate_image(prompt):
if prompt.lower() in ['grass', 'fire', 'water', 'lightning', 'fighting', 'psychic', 'colorless', 'darkness',
'metal', 'dragon', 'fairy']:
prompt = english_to_russian(prompt)
result, _ = generate_images(prompt, tokenizer, model, vae, top_k=2048, images_num=1, top_p=0.995)
buffer = BytesIO()
result[0].save(buffer, format="PNG")
base64_bytes = base64.b64encode(buffer.getvalue())
base64_string = base64_bytes.decode("UTF-8")
return "data:image/png;base64," + base64_string