Spaces:
Runtime error
Runtime error
Commit
·
b5bd188
1
Parent(s):
69e32d1
Fix prediction function
Browse files
app.py
CHANGED
|
@@ -21,7 +21,7 @@ def load_model(ckpt):
|
|
| 21 |
return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)
|
| 22 |
|
| 23 |
@st.cache(persist=True)
|
| 24 |
-
def predict(
|
| 25 |
return np.array(model(pixel_values = transformed_image, **question_inputs)[0][0])
|
| 26 |
|
| 27 |
def softmax(logits):
|
|
@@ -125,7 +125,7 @@ state.answer_lang_id = col2.selectbox('Answer Language', index=options.index(sta
|
|
| 125 |
with st.spinner('Loading model...'):
|
| 126 |
model = load_model(checkpoints[0])
|
| 127 |
with st.spinner('Predicting...'):
|
| 128 |
-
logits = predict(
|
| 129 |
logits = softmax(logits)
|
| 130 |
labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
|
| 131 |
translated_labels = translate_labels(labels, state.answer_lang_id)
|
|
|
|
| 21 |
return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)
|
| 22 |
|
| 23 |
@st.cache(persist=True)
|
| 24 |
+
def predict(transformed_image, question_inputs):
|
| 25 |
return np.array(model(pixel_values = transformed_image, **question_inputs)[0][0])
|
| 26 |
|
| 27 |
def softmax(logits):
|
|
|
|
| 125 |
with st.spinner('Loading model...'):
|
| 126 |
model = load_model(checkpoints[0])
|
| 127 |
with st.spinner('Predicting...'):
|
| 128 |
+
logits = predict(transformed_image, dict(question_inputs))
|
| 129 |
logits = softmax(logits)
|
| 130 |
labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
|
| 131 |
translated_labels = translate_labels(labels, state.answer_lang_id)
|