masanorihirano
committed on
Commit
•
1e5d5c7
1
Parent(s):
7680f1c
update
Browse files- Dockerfile +7 -2
- app.py +7 -7
Dockerfile
CHANGED
@@ -15,9 +15,7 @@ RUN sed -i 's http://deb.debian.org http://cdn-aws.deb.debian.org g' /etc/apt/so
|
|
15 |
rm -rf /var/lib/apt/lists/* && \
|
16 |
git lfs install
|
17 |
|
18 |
-
COPY --link --chown=1000 ./ /home/user/app
|
19 |
RUN useradd -m -u 1000 user
|
20 |
-
WORKDIR /home/user/app
|
21 |
USER user
|
22 |
|
23 |
RUN curl https://pyenv.run | bash
|
@@ -33,10 +31,17 @@ RUN eval "$(pyenv init -)" && \
|
|
33 |
curl -sSL https://install.python-poetry.org | python -
|
34 |
ENV PATH /home/user/.local/bin:${PATH}
|
35 |
|
|
|
|
|
|
|
|
|
36 |
RUN poetry install
|
37 |
RUN --mount=type=secret,id=HF_TOKEN,mode=0444,required=true \
|
38 |
git config --global credential.helper store && \
|
39 |
huggingface-cli login --token $(cat /run/secrets/HF_TOKEN) --add-to-git-credential
|
40 |
RUN poetry run python model_pull.py
|
|
|
|
|
|
|
41 |
EXPOSE 7860
|
42 |
ENTRYPOINT ["/home/user/.local/bin/poetry", "run", "python", "app.py", "--host", "0.0.0.0", "--port", "7860"]
|
|
|
15 |
rm -rf /var/lib/apt/lists/* && \
|
16 |
git lfs install
|
17 |
|
|
|
18 |
RUN useradd -m -u 1000 user
|
|
|
19 |
USER user
|
20 |
|
21 |
RUN curl https://pyenv.run | bash
|
|
|
31 |
curl -sSL https://install.python-poetry.org | python -
|
32 |
ENV PATH /home/user/.local/bin:${PATH}
|
33 |
|
34 |
+
COPY --link --chown=1000 ./pyproject.toml /home/user/app/pyproject.toml
|
35 |
+
COPY --link --chown=1000 ./model_pull.py /home/user/app/model_pull.py
|
36 |
+
WORKDIR /home/user/app
|
37 |
+
|
38 |
RUN poetry install
|
39 |
RUN --mount=type=secret,id=HF_TOKEN,mode=0444,required=true \
|
40 |
git config --global credential.helper store && \
|
41 |
huggingface-cli login --token $(cat /run/secrets/HF_TOKEN) --add-to-git-credential
|
42 |
RUN poetry run python model_pull.py
|
43 |
+
|
44 |
+
COPY --link --chown=1000 ./app.py /home/user/app/app.py
|
45 |
+
|
46 |
EXPOSE 7860
|
47 |
ENTRYPOINT ["/home/user/.local/bin/poetry", "run", "python", "app.py", "--host", "0.0.0.0", "--port", "7860"]
|
app.py
CHANGED
@@ -29,31 +29,31 @@ except Exception:
|
|
29 |
if device == "cuda":
|
30 |
model = LlamaForCausalLM.from_pretrained(
|
31 |
BASE_MODEL,
|
32 |
-
load_in_8bit=
|
33 |
-
torch_dtype=torch.float16,
|
34 |
device_map="auto",
|
35 |
)
|
36 |
-
model = PeftModel.from_pretrained(model, LORA_WEIGHTS,
|
37 |
elif device == "mps":
|
38 |
model = LlamaForCausalLM.from_pretrained(
|
39 |
BASE_MODEL,
|
40 |
device_map={"": device},
|
41 |
-
|
42 |
)
|
43 |
model = PeftModel.from_pretrained(
|
44 |
model,
|
45 |
LORA_WEIGHTS,
|
46 |
device_map={"": device},
|
47 |
-
|
48 |
)
|
49 |
else:
|
50 |
model = LlamaForCausalLM.from_pretrained(
|
51 |
-
BASE_MODEL, device_map={"": device}, low_cpu_mem_usage=True
|
52 |
)
|
53 |
model = PeftModel.from_pretrained(
|
54 |
model,
|
55 |
LORA_WEIGHTS,
|
56 |
device_map={"": device},
|
|
|
57 |
)
|
58 |
|
59 |
|
@@ -136,4 +136,4 @@ g = gr.Interface(
|
|
136 |
)
|
137 |
g.queue(concurrency_count=1)
|
138 |
print("loading completed")
|
139 |
-
g.launch()
|
|
|
29 |
if device == "cuda":
|
30 |
model = LlamaForCausalLM.from_pretrained(
|
31 |
BASE_MODEL,
|
32 |
+
load_in_8bit=True,
|
|
|
33 |
device_map="auto",
|
34 |
)
|
35 |
+
model = PeftModel.from_pretrained(model, LORA_WEIGHTS, load_in_8bit=True)
|
36 |
elif device == "mps":
|
37 |
model = LlamaForCausalLM.from_pretrained(
|
38 |
BASE_MODEL,
|
39 |
device_map={"": device},
|
40 |
+
load_in_8bit=True
|
41 |
)
|
42 |
model = PeftModel.from_pretrained(
|
43 |
model,
|
44 |
LORA_WEIGHTS,
|
45 |
device_map={"": device},
|
46 |
+
load_in_8bit=True
|
47 |
)
|
48 |
else:
|
49 |
model = LlamaForCausalLM.from_pretrained(
|
50 |
+
BASE_MODEL, device_map={"": device}, load_in_8bit=True, low_cpu_mem_usage=True
|
51 |
)
|
52 |
model = PeftModel.from_pretrained(
|
53 |
model,
|
54 |
LORA_WEIGHTS,
|
55 |
device_map={"": device},
|
56 |
+
load_in_8bit=True
|
57 |
)
|
58 |
|
59 |
|
|
|
136 |
)
|
137 |
g.queue(concurrency_count=1)
|
138 |
print("loading completed")
|
139 |
+
g.launch(server_name="0.0.0.0", server_port=7860)
|