Spaces:
Runtime error
Runtime error
reorganized and hopefully fixed torch.cat
Browse files
app.py
CHANGED
@@ -28,11 +28,18 @@ with st.spinner("Loading Model"):
|
|
28 |
valid_subs.insert(0, None)
|
29 |
|
30 |
random_image = get_rand_img(sample_images)
|
|
|
31 |
|
32 |
st.sidebar.title("Select a sample image")
|
|
|
|
|
|
|
|
|
|
|
33 |
sample_image = st.sidebar.selectbox(
|
34 |
"",
|
35 |
-
sample_images
|
|
|
36 |
)
|
37 |
|
38 |
st.sidebar.title("Select a Subreddit")
|
@@ -46,10 +53,6 @@ cap_prompt = st.sidebar.text_input(
|
|
46 |
"Leave this blank for an unbiased caption",
|
47 |
value=""
|
48 |
)
|
49 |
-
|
50 |
-
if st.sidebar.button("Random Sample Image"):
|
51 |
-
random_image = get_rand_img(sample_images)
|
52 |
-
sample_image = None
|
53 |
|
54 |
|
55 |
uploaded_image = None
|
|
|
28 |
valid_subs.insert(0, None)
|
29 |
|
30 |
random_image = get_rand_img(sample_images)
|
31 |
+
rand_idx = 0
|
32 |
|
33 |
st.sidebar.title("Select a sample image")
|
34 |
+
|
35 |
+
if st.sidebar.button("Random Sample Image"):
|
36 |
+
rand_idx, random_image = get_rand_img(sample_images)
|
37 |
+
sample_image = None
|
38 |
+
|
39 |
sample_image = st.sidebar.selectbox(
|
40 |
"",
|
41 |
+
sample_images,
|
42 |
+
index=rand_idx
|
43 |
)
|
44 |
|
45 |
st.sidebar.title("Select a Subreddit")
|
|
|
53 |
"Leave this blank for an unbiased caption",
|
54 |
value=""
|
55 |
)
|
|
|
|
|
|
|
|
|
56 |
|
57 |
|
58 |
uploaded_image = None
|
model.py
CHANGED
@@ -64,7 +64,7 @@ class VirTexModel():
|
|
64 |
|
65 |
if prompt is not "":
|
66 |
cap_tokens = self.tokenizer.encode(prompt)
|
67 |
-
subreddit_tokens = torch.cat(
|
68 |
|
69 |
predictions: List[Dict[str, Any]] = []
|
70 |
|
@@ -105,5 +105,6 @@ def get_samples():
|
|
105 |
return glob.glob(SAMPLES_PATH)
|
106 |
|
107 |
def get_rand_img(samples):
|
108 |
-
|
|
|
109 |
|
|
|
64 |
|
65 |
if prompt is not "":
|
66 |
cap_tokens = self.tokenizer.encode(prompt)
|
67 |
+
subreddit_tokens = torch.cat([subreddit_tokens, cap_tokens])
|
68 |
|
69 |
predictions: List[Dict[str, Any]] = []
|
70 |
|
|
|
105 |
return glob.glob(SAMPLES_PATH)
|
106 |
|
107 |
def get_rand_img(samples):
|
108 |
+
i = random.randint(0,len(samples)-1)
|
109 |
+
return i, samples[i]
|
110 |
|