yuntian-deng commited on
Commit
b8af58f
1 Parent(s): 92789a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -7
app.py CHANGED
@@ -7,6 +7,11 @@ import os
7
  from datetime import datetime
8
  from pytz import timezone
9
 
 
 
 
 
 
10
  tz = timezone('EST')
11
 
12
  API_ENDPOINT = os.getenv('API_ENDPOINT')
@@ -20,6 +25,43 @@ authors = "<center>Yuntian Deng, Noriyuki Kojima, Alexander M. Rush</center>"
20
  info = '<center><a href="https://openreview.net/pdf?id=81VJDmOE2ol">Paper</a> <a href="https://github.com/da03/markup2im">Code</a></center>'
21
  notice = "<p><center><strong>Notice:</strong> Due to resource constraints, we've transitioned from GPU to CPU processing for this demo, which results in significantly longer inference times. We appreciate your understanding.</center></p>"
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  with gr.Blocks() as demo:
24
  gr.Markdown(title)
25
  gr.Markdown(authors)
@@ -38,17 +80,31 @@ with gr.Blocks() as demo:
38
  current_time = datetime.now(tz)
39
  print (current_time, formula)
40
  data = {'formula': formula, 'api_key': API_KEY}
 
 
 
41
  try:
42
- with requests.post(url=API_ENDPOINT, data=data, timeout=600, stream=True) as r:
43
- i = 0
44
- for line in r.iter_lines():
45
- response = line.decode('ascii').strip()
46
- r = base64.decodebytes(response.encode('ascii'))
47
- q = np.frombuffer(r, dtype=np.float32).reshape((64, 320, 3))
 
 
 
 
 
 
 
 
 
 
 
48
  i += 1
49
  yield i, q, submit_btn.update(visible=False)
50
  yield i, q, submit_btn.update(visible=True)
51
  except Exception as e:
52
  yield 1000, 255*np.ones((64, 320, 3)), submit_btn.update(visible=True)
53
  submit_btn.click(fn=infer, inputs=inputs, outputs=outputs)
54
- demo.queue(concurrency_count=20, max_size=200).launch(enable_queue=True)
 
7
  from datetime import datetime
8
  from pytz import timezone
9
 
10
+ import torch
11
+ import diffusers
12
+ from diffusers import DDPMPipeline
13
+ from transformers import AutoTokenizer, AutoModel
14
+
15
  tz = timezone('EST')
16
 
17
  API_ENDPOINT = os.getenv('API_ENDPOINT')
 
25
  info = '<center><a href="https://openreview.net/pdf?id=81VJDmOE2ol">Paper</a> <a href="https://github.com/da03/markup2im">Code</a></center>'
26
  notice = "<p><center><strong>Notice:</strong> Due to resource constraints, we've transitioned from GPU to CPU processing for this demo, which results in significantly longer inference times. We appreciate your understanding.</center></p>"
27
 
28
+
29
+ # setup
30
+ def setup(device='cuda'):
31
+ device = ("cuda" if torch.cuda.is_available() else "cpu")
32
+ img_pipe = DDPMPipeline.from_pretrained("yuntian-deng/latex2im_ss_finetunegptneo")
33
+ img_pipe.to(device)
34
+
35
+ model_type = "EleutherAI/gpt-neo-125M"
36
+ #encoder = AutoModel.from_pretrained(model_type).to(device)
37
+ encoder = img_pipe.unet.text_encoder
38
+ if True:
39
+ l = len(img_pipe.unet.down_blocks)
40
+ for i in range(l):
41
+ img_pipe.unet.down_blocks[i] = torch.compile(img_pipe.unet.down_blocks[i])
42
+ l = len(img_pipe.unet.up_blocks)
43
+ for i in range(l):
44
+ img_pipe.unet.up_blocks[i] = torch.compile(img_pipe.unet.up_blocks[i])
45
+ tokenizer = AutoTokenizer.from_pretrained(model_type, max_length=1024)
46
+ eos_id = tokenizer.encode(tokenizer.eos_token)[0]
47
+
48
+ def forward_encoder(latex):
49
+ encoded = tokenizer(latex, return_tensors='pt', truncation=True, max_length=1024)
50
+ input_ids = encoded['input_ids']
51
+ input_ids = torch.cat((input_ids, torch.LongTensor([eos_id,]).unsqueeze(0)), dim=-1)
52
+ input_ids = input_ids.to(device)
53
+ attention_mask = encoded['attention_mask']
54
+ attention_mask = torch.cat((attention_mask, torch.LongTensor([1,]).unsqueeze(0)), dim=-1)
55
+ attention_mask = attention_mask.to(device)
56
+ with torch.no_grad():
57
+ outputs = encoder(input_ids=input_ids, attention_mask=attention_mask)
58
+ last_hidden_state = outputs.last_hidden_state
59
+ last_hidden_state = attention_mask.unsqueeze(-1) * last_hidden_state # shouldn't be necessary
60
+ return last_hidden_state
61
+ return img_pipe, forward_encoder
62
+
63
+ img_pipe, forward_encoder = setup()
64
+
65
  with gr.Blocks() as demo:
66
  gr.Markdown(title)
67
  gr.Markdown(authors)
 
80
  current_time = datetime.now(tz)
81
  print (current_time, formula)
82
  data = {'formula': formula, 'api_key': API_KEY}
83
+ latex = formula # TODO: normalize
84
+ encoder_hidden_states = forward_encoder(latex)
85
+
86
  try:
87
+ i = 0
88
+ results = []
89
+ for _, image_clean in img_pipe.run_clean(batch_size=1, generator=torch.manual_seed(0), encoder_hidden_states=encoder_hidden_states, output_type="numpy"):
90
+ i += 1
91
+ image_clean = image_clean[0]
92
+ image_clean = np.ascontiguousarray(image_clean)
93
+ #s = base64.b64encode(image_clean).decode('ascii')
94
+ #yield s
95
+ q = image_clean
96
+ yield i, q, submit_btn.update(visible=False)
97
+ yield i, q, submit_btn.update(visible=True)
98
+ #with requests.post(url=API_ENDPOINT, data=data, timeout=600, stream=True) as r:
99
+ # i = 0
100
+ # for line in r.iter_lines():
101
+ # response = line.decode('ascii').strip()
102
+ # r = base64.decodebytes(response.encode('ascii'))
103
+ # q = np.frombuffer(r, dtype=np.float32).reshape((64, 320, 3))
104
  i += 1
105
  yield i, q, submit_btn.update(visible=False)
106
  yield i, q, submit_btn.update(visible=True)
107
  except Exception as e:
108
  yield 1000, 255*np.ones((64, 320, 3)), submit_btn.update(visible=True)
109
  submit_btn.click(fn=infer, inputs=inputs, outputs=outputs)
110
+ demo.queue(concurrency_count=1, max_size=20).launch(enable_queue=True)