keithhon commited on
Commit
7c100db
1 Parent(s): e27e6cf

Upload server.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. server.py +72 -0
server.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import base64
4
+ from io import BytesIO
5
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
6
+
7
+ import torch
8
+ from torch import nn
9
+ from fastapi import FastAPI
10
+ import numpy as np
11
+ from PIL import Image
12
+
13
+ from dalle.models import Dalle
14
+ import logging
15
+ import streamlit as st
16
+
17
+
18
+ print("Loading models...")
19
+ app = FastAPI()
20
+
21
+ from huggingface_hub import hf_hub_download
22
+
23
+ logging.info("Start downloading")
24
+ full_dict_path = hf_hub_download(repo_id="ml6team/logo-generator", filename="full_dict_new.ckpt",
25
+ use_auth_token=st.secrets["model_download"])
26
+ logging.info("End downloading")
27
+
28
+ device = "cuda" if torch.cuda.is_available() else "cpu"
29
+ model = Dalle.from_pretrained("minDALL-E/1.3B")
30
+
31
+ model.load_state_dict(torch.load(full_dict_path, map_location=torch.device('cpu')))
32
+ model.to(device=device)
33
+
34
+ print("Models loaded !")
35
+
36
+
37
+ @app.get("/")
38
+ def read_root():
39
+ return {"minDALL-E!"}
40
+
41
+
42
+ @app.get("/{generate}")
43
+ def generate(prompt):
44
+ images = sample(prompt)
45
+ images = [to_base64(image) for image in images]
46
+ return {"images": images}
47
+
48
+
49
+ def sample(prompt):
50
+ # Sampling
51
+ logging.info("starting sampling")
52
+ images = (
53
+ model.sampling(prompt=prompt, top_k=96, top_p=None, softmax_temperature=1.0, num_candidates=9, device=device)
54
+ .cpu()
55
+ .numpy()
56
+ )
57
+ logging.info("sampling succeeded")
58
+ images = np.transpose(images, (0, 2, 3, 1))
59
+
60
+
61
+ pil_images = []
62
+ for i in range(len(images)):
63
+ im = Image.fromarray((images[i] * 255).astype(np.uint8))
64
+ pil_images.append(im)
65
+
66
+ return pil_images
67
+
68
+
69
+ def to_base64(pil_image):
70
+ buffered = BytesIO()
71
+ pil_image.save(buffered, format="JPEG")
72
+ return base64.b64encode(buffered.getvalue())