AdamNovotnyCom committed
Commit 15a1b0b
1 parent: f0a60ae

multiple models

Files changed (3):
  1. Dockerfile +0 -2
  2. Dockerfile_dev +0 -2
  3. app.py +3 -8
Dockerfile CHANGED
@@ -20,8 +20,6 @@ RUN pip install -r requirements.txt
 
 EXPOSE 7860
 
-ENV MODEL=llama
-
 RUN --mount=type=secret,id=HF_TOKEN,mode=0444,required=true
 
 CMD ["python", "app.py"]
Dockerfile_dev CHANGED
@@ -20,7 +20,5 @@ RUN pip install -r requirements.txt
 
 EXPOSE 7860
 
-ENV MODEL=googleflan
-
 # with reload
 CMD ["gradio", "app.py"]
app.py CHANGED
@@ -6,20 +6,15 @@ import transformers
 from transformers import AutoTokenizer
 
 logging.basicConfig(level=logging.INFO)
+logging.info(f"APP startup")
 
 if "googleflan" == os.environ.get("MODEL"):
     model = "google/flan-t5-small"
-    logging.info(f"APP startup. Model {model}")
     pipe_flan = transformers.pipeline("text2text-generation", model=model)
     def model_func(input_text, request: gr.Request):
-        print(f"Input request: {input_text}")
-        print(request.query_params)
-        print(os.environ.get("HF_TOKEN")[:5])
-        logging.info(os.environ.get("HF_TOKEN")[:5])
         return pipe_flan(input_text)
 elif "llama" == os.environ.get("MODEL"):
     model = "meta-llama/Llama-2-7b-chat-hf"
-    logging.info(f"APP startup. Model {model}")
     tokenizer = AutoTokenizer.from_pretrained(
         model,
         token=os.environ.get("HF_TOKEN"),
@@ -27,7 +22,7 @@ elif "llama" == os.environ.get("MODEL"):
     pipeline = transformers.pipeline(
         "text-generation",
         model=model,
-        torch_dtype=torch.float16,
+        torch_dtype=torch.float32,
         device_map="auto",
         token=os.environ.get("HF_TOKEN"),
     )
@@ -57,7 +52,7 @@ demo = gr.Interface(
         value="",
     ),
     outputs=gr.Textbox(
-        label="LLM",
+        label=f"Model: {os.environ.get('MODEL')}",
         lines=5,
         value="",
     ),
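Two changes stand out: torch_dtype moves from float16 to float32 (float16 generally assumes GPU support, while float32 is the safe default on CPU-only hardware, a plausible reason for the change here), and the output label now reports the active MODEL instead of a static "LLM". A minimal sketch of a hardware-aware dtype choice, assuming the same transformers pipeline API used in the diff; the fallback logic itself is not part of this commit:

import os
import torch
import transformers

# Illustrative only: use half precision when a GPU is present,
# otherwise fall back to full precision, since float16 kernels
# on CPU are often slow or unsupported.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
pipe = transformers.pipeline(
    "text-generation",
    model="meta-llama/Llama-2-7b-chat-hf",
    torch_dtype=dtype,
    device_map="auto",
    token=os.environ.get("HF_TOKEN"),
)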