Joshua Lochner commited on
Commit
9a5d9ed
1 Parent(s): e926596

Use `get_model_tokenizer` method from streamlit app

Browse files
Files changed (1) hide show
  1. app.py +2 -9
app.py CHANGED
@@ -2,10 +2,6 @@
2
  from functools import partial
3
  from math import ceil, floor
4
  import streamlit.components.v1 as components
5
- from transformers import (
6
- AutoModelForSeq2SeqLM,
7
- AutoTokenizer,
8
- )
9
  import streamlit as st
10
  import sys
11
  import os
@@ -20,6 +16,7 @@ from predict import SegmentationArguments, ClassifierArguments, predict as pred,
20
  from evaluate import EvaluationArguments
21
  from shared import device, CATGEGORY_OPTIONS
22
  from utils import regex_search
 
23
 
24
  st.set_page_config(
25
  page_title='SponsorBlock ML',
@@ -144,11 +141,7 @@ def load_predict(model_id):
144
  segmentation_args = SegmentationArguments()
145
  classifier_args = ClassifierArguments()
146
 
147
- model = AutoModelForSeq2SeqLM.from_pretrained(
148
- evaluation_args.model_path)
149
- model.to(device())
150
-
151
- tokenizer = AutoTokenizer.from_pretrained(evaluation_args.model_path)
152
 
153
  download_classifier(classifier_args)
154
 
 
2
  from functools import partial
3
  from math import ceil, floor
4
  import streamlit.components.v1 as components
 
 
 
 
5
  import streamlit as st
6
  import sys
7
  import os
 
16
  from evaluate import EvaluationArguments
17
  from shared import device, CATGEGORY_OPTIONS
18
  from utils import regex_search
19
+ from model import get_model_tokenizer
20
 
21
  st.set_page_config(
22
  page_title='SponsorBlock ML',
 
141
  segmentation_args = SegmentationArguments()
142
  classifier_args = ClassifierArguments()
143
 
144
+ model, tokenizer = get_model_tokenizer(evaluation_args.model_path)
 
 
 
 
145
 
146
  download_classifier(classifier_args)
147