|
import atexit |
|
from io import BytesIO |
|
from multiprocessing.connection import Listener |
|
from os import chmod, remove |
|
from os.path import abspath, exists |
|
from pathlib import Path |
|
from git import Repo |
|
import torch |
|
|
|
from PIL.JpegImagePlugin import JpegImageFile |
|
from pipelines.models import TextToImageRequest |
|
|
|
from pipeline import load_pipeline, infer |
|
|
|
# Absolute path of the Unix domain socket file, located in the parent of the
# directory that contains this source file.
SOCKET = abspath(Path(__file__).parent.parent.joinpath("inferences.sock"))
|
|
|
|
|
def at_exit():
    """Release memory cached by PyTorch's CUDA caching allocator.

    Intended as an ``atexit`` hook; ``torch.cuda.empty_cache`` does nothing
    when CUDA has not been initialized.
    """
    torch.cuda.empty_cache()
|
|
|
|
|
def main():
    """Serve text-to-image inference requests over a Unix domain socket.

    Loads the pipeline once, (re)creates the socket at ``SOCKET``, accepts a
    single client connection, and answers requests on it until the client
    disconnects. Each request is a UTF-8 JSON ``TextToImageRequest``. Each
    response is the generated image, encoded as JPEG bytes.

    Raises:
        RuntimeError: if the pipeline could not be loaded. Failing here,
            before the socket is created, gives a clear error instead of
            serving with ``pipeline=None`` and crashing on the first request.
    """
    atexit.register(at_exit)

    print("Loading pipeline")
    pipeline = _load_pipeline()
    if pipeline is None:
        raise RuntimeError("Failed to load pipeline; not creating inference socket")

    print(f"Pipeline loaded, creating socket at '{SOCKET}'")

    # A leftover socket file from a previous run would prevent binding.
    if exists(SOCKET):
        remove(SOCKET)

    with Listener(SOCKET) as listener:
        # NOTE(review): world read/write/execute, presumably so a client
        # running as another user can connect — confirm this is required.
        chmod(SOCKET, 0o777)

        print("Awaiting connections")
        # Only one client is ever accepted; main returns once it disconnects.
        with listener.accept() as connection:
            print("Connected")

            while True:
                try:
                    payload = connection.recv_bytes()
                except EOFError:
                    # The client closed its end of the connection.
                    print("Inference socket exiting")
                    return

                request = TextToImageRequest.model_validate_json(payload.decode("utf-8"))
                image = infer(request, pipeline)

                data = BytesIO()
                image.save(data, format=JpegImageFile.format)
                connection.send_bytes(data.getvalue())
|
|
|
def _load_pipeline():
    """Load the inference pipeline if the author check passes.

    Reads ``loss_params.pth`` from the current working directory and takes
    ``data["metadata"]["author"]``. The pipeline is loaded only when that
    author string occurs in the URL of the repository's ``origin`` remote.

    Returns:
        The object returned by ``load_pipeline()``, or ``None`` if any of
        these happen:
        - the metadata file is missing or malformed;
        - the remote URL is unavailable;
        - the author does not match;
        - loading the pipeline raises.
    """
    try:
        # NOTE(review): torch.load unpickles its input unless weights_only=True
        # (the default only in recent PyTorch releases); only load this file
        # from a trusted source.
        loaded_data = torch.load("loss_params.pth")
        author = loaded_data["metadata"]["author"]

        remote_url = get_git_remote_url()
        if remote_url is None or author not in remote_url:
            print("Pipeline author check failed")
            return None

        # Load only after the check, so the expensive pipeline is never built
        # just to be discarded.
        return load_pipeline()
    except Exception as e:
        # Best-effort: report the reason instead of swallowing it silently, but
        # keep returning None so the caller decides how to fail.
        print(f"Failed to load pipeline: {e}")
        return None
|
|
|
|
|
def get_git_remote_url():
    """Return the URL of the ``origin`` remote of the Git repository in the cwd.

    Never raises. On any failure (e.g. the cwd is not a repository, or it has
    no ``origin`` remote), the error is printed and ``None`` is returned.
    """
    try:
        return Repo(".").remotes.origin.url
    except Exception as err:
        print(f"Error: {err}")
        return None
|
|
|
if __name__ == "__main__":
    # Start the inference server only when executed as a script, not on import.
    main()
|
|