Spaces:
Runtime error
Runtime error
cap_prompt attempts
Browse files
app.py
CHANGED
@@ -5,9 +5,9 @@ import time
|
|
5 |
import json
|
6 |
sys.path.append("./virtex/")
|
7 |
|
8 |
-
def gen_show_caption(sub_prompt=None, cap_prompt =
|
9 |
with st.spinner("Generating Caption"):
|
10 |
-
subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt)
|
11 |
st.header("Predicted Caption:\n\n")
|
12 |
st.subheader(f"r/{subreddit}:\t{caption}\n")
|
13 |
|
@@ -44,12 +44,9 @@ sub = st.sidebar.selectbox(
|
|
44 |
st.sidebar.title("Write a Custom Prompt")
|
45 |
cap_prompt = st.sidebar.text_input(
|
46 |
"Leave this blank for an unbiased caption",
|
47 |
-
value=
|
48 |
)
|
49 |
|
50 |
-
if cap_prompt is None:
|
51 |
-
st.write("HAHA")
|
52 |
-
|
53 |
if st.sidebar.button("Random Sample Image"):
|
54 |
random_image = get_rand_img(sample_images)
|
55 |
sample_image = None
|
|
|
5 |
import json
|
6 |
sys.path.append("./virtex/")
|
7 |
|
8 |
+
def gen_show_caption(sub_prompt=None, cap_prompt = ""):
|
9 |
with st.spinner("Generating Caption"):
|
10 |
+
subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt = cap_prompt)
|
11 |
st.header("Predicted Caption:\n\n")
|
12 |
st.subheader(f"r/{subreddit}:\t{caption}\n")
|
13 |
|
|
|
44 |
st.sidebar.title("Write a Custom Prompt")
|
45 |
cap_prompt = st.sidebar.text_input(
|
46 |
"Leave this blank for an unbiased caption",
|
47 |
+
value=""
|
48 |
)
|
49 |
|
|
|
|
|
|
|
50 |
if st.sidebar.button("Random Sample Image"):
|
51 |
random_image = get_rand_img(sample_images)
|
52 |
sample_image = None
|
model.py
CHANGED
@@ -62,6 +62,9 @@ class VirTexModel():
|
|
62 |
)
|
63 |
subreddit_tokens = torch.tensor(subreddit_tokens, device=self.device).long()
|
64 |
|
|
|
|
|
|
|
65 |
|
66 |
predictions: List[Dict[str, Any]] = []
|
67 |
|
|
|
62 |
)
|
63 |
subreddit_tokens = torch.tensor(subreddit_tokens, device=self.device).long()
|
64 |
|
65 |
+
if prompt is not "":
|
66 |
+
cap_tokens = self.tokenizer.encode(subreddit_tokens)
|
67 |
+
subreddit_tokens = torch.cat((subreddit_tokens, cap_tokens))
|
68 |
|
69 |
predictions: List[Dict[str, Any]] = []
|
70 |
|