yonikremer commited on
Commit
7b34e37
1 Parent(s): 2fd3831

models will be downloaded when the server starts

Browse files
Files changed (2) hide show
  1. app.py +1 -0
  2. on_server_start.py +37 -0
app.py CHANGED
@@ -7,6 +7,7 @@ In the demo, the user can write a prompt
7
  import streamlit as st
8
 
9
  from hanlde_form_submit import on_form_submit
 
10
 
11
 
12
  AVAILABLE_MODEL_NAMES = "https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads"
 
7
  import streamlit as st
8
 
9
  from hanlde_form_submit import on_form_submit
10
+ from on_server_start import main as on_server_start_main
11
 
12
 
13
  AVAILABLE_MODEL_NAMES = "https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads"
on_server_start.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A script that is run when the server starts.
3
+ """
4
+
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
+
7
+
8
+ def download_model(model_name: str):
9
+ """
10
+ Downloads the model with the given name.
11
+ :param model_name: The name of the model to download.
12
+ """
13
+ AutoModelForCausalLM.from_pretrained(model_name)
14
+ AutoTokenizer.from_pretrained(model_name)
15
+
16
+
17
+ def download_useful_models():
18
+ """
19
+ Downloads the models that are useful for this project.
20
+ So that the user doesn't have to wait for the models to download when they first use the app.
21
+ """
22
+ useful_models = (
23
+ "gpt2",
24
+ "EleutherAI/gpt-j-6B",
25
+ "sberbank-ai/mGPT",
26
+ "facebook/opt-125m",
27
+ )
28
+ for model_name in useful_models:
29
+ download_model(model_name)
30
+
31
+
32
+ def main():
33
+ download_useful_models()
34
+
35
+
36
+ if __name__ == "__main__":
37
+ main()