zamborg commited on
Commit
7cc986f
1 Parent(s): 486bf48

nucleus size updates in advanced"

Browse files
Files changed (2) hide show
  1. app.py +7 -11
  2. model.py +1 -1
app.py CHANGED
@@ -6,6 +6,11 @@ import json
6
  sys.path.append("./virtex/")
7
  from model import *
8
 
 
 
 
 
 
9
  def gen_show_caption(sub_prompt=None, cap_prompt = ""):
10
  with st.spinner("Generating Caption"):
11
  if sub_prompt is None and cap_prompt is not "":
@@ -102,6 +107,7 @@ num_captions=1
102
  if advanced:
103
  nuc_size = st.sidebar.slider("Nucelus Size:", min_value=0.0, max_value=1.0, value=0.8, step=0.05)
104
  num_captions = st.sidebar.select_slider("Number of Captions to Predict", options=[1,2,3,4,5], value=1)
 
105
 
106
  if uploaded_image is None and submitted:
107
  st.write("Please select a file to upload")
@@ -130,14 +136,4 @@ else:
130
  show.image(show_image, "Your Image")
131
 
132
  for i in range(num_captions):
133
- gen_show_caption(sub, imageLoader.text_transform(cap_prompt))
134
-
135
- # from model import *
136
- # sample_images = get_samples()
137
- # v, il = VirTexModel(), ImageLoader()
138
-
139
- # for s in sample_images:
140
- # subreddit, caption = v.predict(il.load(s))
141
- # print("=====================")
142
- # print(subreddit)
143
- # print(caption)
6
  sys.path.append("./virtex/")
7
  from model import *
8
 
9
+ # # TODO:
10
+ # - Reformat the model introduction
11
+ # - Center the images using the 3 column method
12
+ # - Make the iterative text generation
13
+
14
  def gen_show_caption(sub_prompt=None, cap_prompt = ""):
15
  with st.spinner("Generating Caption"):
16
  if sub_prompt is None and cap_prompt is not "":
107
  if advanced:
108
  nuc_size = st.sidebar.slider("Nucelus Size:", min_value=0.0, max_value=1.0, value=0.8, step=0.05)
109
  num_captions = st.sidebar.select_slider("Number of Captions to Predict", options=[1,2,3,4,5], value=1)
110
+ virtexModel.model.decoder.nucleus_size = nuc_size
111
 
112
  if uploaded_image is None and submitted:
113
  st.write("Please select a file to upload")
136
  show.image(show_image, "Your Image")
137
 
138
  for i in range(num_captions):
139
+ gen_show_caption(sub, imageLoader.text_transform(cap_prompt))
 
 
 
 
 
 
 
 
 
 
model.py CHANGED
@@ -61,7 +61,7 @@ class VirTexModel():
61
  self.device = 'cpu'
62
  self.tokenizer = TokenizerFactory.from_config(self.config)
63
  self.model = PretrainingModelFactory.from_config(self.config).to(self.device)
64
- CheckpointManager(model=self.model).load("./checkpoint_last5.pth")
65
  self.model.eval()
66
  self.valid_subs = json.load(open(VALID_SUBREDDITS_PATH))
67
 
61
  self.device = 'cpu'
62
  self.tokenizer = TokenizerFactory.from_config(self.config)
63
  self.model = PretrainingModelFactory.from_config(self.config).to(self.device)
64
+ CheckpointManager(model=self.model).load(MODEL_PATH)
65
  self.model.eval()
66
  self.valid_subs = json.load(open(VALID_SUBREDDITS_PATH))
67