Spaces:
Running
Running
Joshua Lochner
commited on
Commit
•
23a1215
1
Parent(s):
a294fb2
Use partial functions to allow pickling of prediction functions
Browse files
app.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
|
|
|
2 |
from math import ceil, floor
|
3 |
import streamlit.components.v1 as components
|
4 |
from transformers import (
|
@@ -86,6 +87,16 @@ def download_classifier(classifier_args):
|
|
86 |
return True
|
87 |
|
88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
@st.cache(persist=True, allow_output_mutation=True)
|
90 |
def load_predict(model_id):
|
91 |
model_info = MODELS[model_id]
|
@@ -102,16 +113,7 @@ def load_predict(model_id):
|
|
102 |
|
103 |
download_classifier(classifier_args)
|
104 |
|
105 |
-
|
106 |
-
if video_id not in prediction_cache[model_id]:
|
107 |
-
prediction_cache[model_id][video_id] = pred(
|
108 |
-
video_id, model, tokenizer,
|
109 |
-
segmentation_args=segmentation_args,
|
110 |
-
classifier_args=classifier_args
|
111 |
-
)
|
112 |
-
return prediction_cache[model_id][video_id]
|
113 |
-
|
114 |
-
return predict_function
|
115 |
|
116 |
|
117 |
def main():
|
@@ -192,5 +194,6 @@ def main():
|
|
192 |
wiki_link = '[Review generated segments before submitting!](https://wiki.sponsor.ajay.app/w/Automating_Submissions)'
|
193 |
st.markdown(wiki_link, unsafe_allow_html=True)
|
194 |
|
|
|
195 |
if __name__ == '__main__':
|
196 |
main()
|
|
|
1 |
|
2 |
+
from functools import partial
|
3 |
from math import ceil, floor
|
4 |
import streamlit.components.v1 as components
|
5 |
from transformers import (
|
|
|
87 |
return True
|
88 |
|
89 |
|
90 |
+
def predict_function(model_id, model, tokenizer, segmentation_args, classifier_args, video_id):
|
91 |
+
if video_id not in prediction_cache[model_id]:
|
92 |
+
prediction_cache[model_id][video_id] = pred(
|
93 |
+
video_id, model, tokenizer,
|
94 |
+
segmentation_args=segmentation_args,
|
95 |
+
classifier_args=classifier_args
|
96 |
+
)
|
97 |
+
return prediction_cache[model_id][video_id]
|
98 |
+
|
99 |
+
|
100 |
@st.cache(persist=True, allow_output_mutation=True)
|
101 |
def load_predict(model_id):
|
102 |
model_info = MODELS[model_id]
|
|
|
113 |
|
114 |
download_classifier(classifier_args)
|
115 |
|
116 |
+
return partial(predict_function, model_id, model, tokenizer, segmentation_args, classifier_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
|
119 |
def main():
|
|
|
194 |
wiki_link = '[Review generated segments before submitting!](https://wiki.sponsor.ajay.app/w/Automating_Submissions)'
|
195 |
st.markdown(wiki_link, unsafe_allow_html=True)
|
196 |
|
197 |
+
|
198 |
if __name__ == '__main__':
|
199 |
main()
|