zamborg commited on
Commit
0674c7e
1 Parent(s): 3408c6b

reorganized and hopefully fixed torch.cat

Browse files
Files changed (2) hide show
  1. app.py +8 -5
  2. model.py +3 -2
app.py CHANGED
@@ -28,11 +28,18 @@ with st.spinner("Loading Model"):
28
  valid_subs.insert(0, None)
29
 
30
  random_image = get_rand_img(sample_images)
 
31
 
32
  st.sidebar.title("Select a sample image")
 
 
 
 
 
33
  sample_image = st.sidebar.selectbox(
34
  "",
35
- sample_images
 
36
  )
37
 
38
  st.sidebar.title("Select a Subreddit")
@@ -46,10 +53,6 @@ cap_prompt = st.sidebar.text_input(
46
  "Leave this blank for an unbiased caption",
47
  value=""
48
  )
49
-
50
- if st.sidebar.button("Random Sample Image"):
51
- random_image = get_rand_img(sample_images)
52
- sample_image = None
53
 
54
 
55
  uploaded_image = None
 
28
  valid_subs.insert(0, None)
29
 
30
  random_image = get_rand_img(sample_images)
31
+ rand_idx = 0
32
 
33
  st.sidebar.title("Select a sample image")
34
+
35
+ if st.sidebar.button("Random Sample Image"):
36
+ rand_idx, random_image = get_rand_img(sample_images)
37
+ sample_image = None
38
+
39
  sample_image = st.sidebar.selectbox(
40
  "",
41
+ sample_images,
42
+ index=rand_idx
43
  )
44
 
45
  st.sidebar.title("Select a Subreddit")
 
53
  "Leave this blank for an unbiased caption",
54
  value=""
55
  )
 
 
 
 
56
 
57
 
58
  uploaded_image = None
model.py CHANGED
@@ -64,7 +64,7 @@ class VirTexModel():
64
 
65
  if prompt is not "":
66
  cap_tokens = self.tokenizer.encode(prompt)
67
- subreddit_tokens = torch.cat((subreddit_tokens, cap_tokens))
68
 
69
  predictions: List[Dict[str, Any]] = []
70
 
@@ -105,5 +105,6 @@ def get_samples():
105
  return glob.glob(SAMPLES_PATH)
106
 
107
  def get_rand_img(samples):
108
- return samples[random.randint(0,len(samples)-1)]
 
109
 
 
64
 
65
  if prompt is not "":
66
  cap_tokens = self.tokenizer.encode(prompt)
67
+ subreddit_tokens = torch.cat([subreddit_tokens, cap_tokens])
68
 
69
  predictions: List[Dict[str, Any]] = []
70
 
 
105
  return glob.glob(SAMPLES_PATH)
106
 
107
  def get_rand_img(samples):
108
+ i = random.randint(0,len(samples)-1)
109
+ return i, samples[i]
110