ApolloPro7 committed
Commit 95cbdfb · Parent: 8c890b0

HuggingFace Space deployment

Files changed (5):
  1. Dockerfile +10 -14
  2. app.py +47 -0
  3. requirements.txt +44 -2
  4. start.sh +5 -0
  5. streamlit_app.py +18 -0
Dockerfile CHANGED
@@ -1,21 +1,17 @@
+# Use the official Python image
 FROM python:3.9-slim
 
+# Set the working directory
 WORKDIR /app
 
-RUN apt-get update && apt-get install -y \
-    build-essential \
-    curl \
-    software-properties-common \
-    git \
-    && rm -rf /var/lib/apt/lists/*
-
-COPY requirements.txt ./
-COPY src/ ./src/
-
-RUN pip3 install -r requirements.txt
-
-EXPOSE 8501
-
-HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
-
-ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
+# Copy dependencies and code
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+COPY . .
+
+# Set the default port
+ENV PORT=7860
+
+# Start the FastAPI service
+CMD ["./start.sh"]
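One caveat not visible in the diff: CMD ["./start.sh"] only works if start.sh carries the executable bit inside the image. Whether COPY . . preserves that bit depends on the file's mode in the repository; if it is not set, adding a RUN chmod +x start.sh line after the COPY is the usual fix.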
app.py ADDED
@@ -0,0 +1,47 @@
+import os
+from transformers import AutoTokenizer, AutoModelForCausalLM, T5Tokenizer, T5ForConditionalGeneration
+from peft import PeftModel
+from fastapi import FastAPI
+from pydantic import BaseModel
+from huggingface_hub import login
+
+login(token=os.getenv("HF_TOKEN"))
+print("Hugging Face login succeeded!")
+
+app = FastAPI()
+
+# Load fine-tuned model and tokenizer
+# tokenizer = AutoTokenizer.from_pretrained("./llama2-7b", local_files_only=True)
+# base_model = AutoModelForCausalLM.from_pretrained("./llama2-7b", local_files_only=True)
+# model = PeftModel.from_pretrained(base_model, "./checkpoint-5400", local_files_only=True)
+
+
+tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
+model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")
+
+# Define the request body schema
+class PromptInput(BaseModel):
+    prompt: str
+
+# Define the API endpoint
+# @app.post("/generate")
+# def generate_script(input: PromptInput):
+#     print("Starts Generating!")
+#     inputs = tokenizer(input.prompt, return_tensors="pt")
+#     print("Inputs Tokenized! Generating Begins~")
+#     outputs = model.generate(**inputs, max_new_tokens=200)
+#     print("Generating Succeed!")
+#     result = tokenizer.decode(outputs[0], skip_special_tokens=True)
+#     print("Results formed!")
+#     return {"generated_script": result}
+
+@app.post("/generate")
+def generate_script(input: PromptInput):
+    print("Generation started!")
+    inputs = tokenizer(input.prompt, return_tensors="pt").input_ids
+    print("Inputs tokenized; generation begins.")
+    outputs = model.generate(inputs)  # uses the model's default generation length
+    print("Generation succeeded!")
+    result = tokenizer.decode(outputs[0], skip_special_tokens=True)  # drop <pad>/</s> tokens
+    print("Result decoded!")
+    return {"generated_script": result}
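For reference, a minimal client sketch for exercising this endpoint. It assumes the service is reachable on localhost:8000 (as start.sh below arranges); the prompt string is purely illustrative:

import requests

# Post a prompt to the running FastAPI service and print the generated text.
resp = requests.post(
    "http://localhost:8000/generate",
    json={"prompt": "Write a short scene where two friends argue over pizza."},
    timeout=120,  # generation on CPU can be slow
)
resp.raise_for_status()
print(resp.json()["generated_script"])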
requirements.txt CHANGED
@@ -1,3 +1,45 @@
-altair
-pandas
+accelerate==1.6.0
+annotated-types==0.7.0
+anyio==4.9.0
+certifi==2025.4.26
+charset-normalizer==3.4.2
+click==8.1.8
+exceptiongroup==1.2.2
+fastapi==0.115.12
+filelock==3.18.0
+fsspec==2025.3.2
+h11==0.16.0
+httptools==0.6.4
+huggingface-hub==0.30.2
+idna==3.10
+Jinja2==3.1.6
+MarkupSafe==3.0.2
+mpmath==1.3.0
+networkx==3.2.1
+numpy==2.0.2
+packaging==25.0
+peft==0.15.2
+psutil==7.0.0
+pydantic==2.11.4
+pydantic_core==2.33.2
+python-dotenv==1.1.0
+PyYAML==6.0.2
+regex==2024.11.6
+requests==2.32.3
+safetensors==0.5.3
+sniffio==1.3.1
+starlette==0.46.2
+sympy==1.14.0
+tokenizers==0.21.1
+torch==2.7.0
+tqdm==4.67.1
+transformers==4.51.3
+typing-inspection==0.4.0
+typing_extensions==4.13.2
+urllib3==2.4.0
+uvicorn==0.34.2
+uvloop==0.21.0
+watchfiles==1.0.5
+websockets==15.0.1
+sentencepiece
 streamlit
start.sh ADDED
@@ -0,0 +1,5 @@
+#!/bin/bash
+
+uvicorn app:app --host 0.0.0.0 --port 8000 &
+
+streamlit run ./streamlit_app.py --server.port 7860 --server.enableCORS false
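This launch pattern backgrounds uvicorn with & and leaves Streamlit in the foreground as the container's main process on port 7860, the port a Hugging Face Space exposes. Because app.py loads the model at import time, uvicorn only begins accepting connections once loading is done, so a simple TCP check is enough to know the backend is ready. A sketch of such a guard that streamlit_app.py could call before posting (a hypothetical helper, not part of this commit):

import socket
import time

def wait_for_backend(host="localhost", port=8000, timeout=60):
    # Poll the backend's TCP port until it accepts connections or we time out.
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            with socket.create_connection((host, port), timeout=2):
                return True
        except OSError:
            time.sleep(1)
    return False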
streamlit_app.py ADDED
@@ -0,0 +1,18 @@
+import streamlit as st
+import requests
+
+st.title("LLaMA2 TV Script Generator")
+prompt = st.text_area("Enter your prompt:")
+
+if st.button("Generate"):
+    with st.spinner("Generating..."):
+        response = requests.post(
+            "http://localhost:8000/generate",
+            json={"prompt": prompt}
+        )
+
+    if response.ok:
+        st.markdown("### Output")
+        st.write(response.json()["generated_script"])
+    else:
+        st.error("Something went wrong.")
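A note on robustness: the requests.post call above sets no timeout, so a stalled backend leaves the spinner hanging, and a non-OK response hides the underlying error. A hedged variant of the same call body (an adjustment sketch, not part of this commit):

try:
    response = requests.post(
        "http://localhost:8000/generate",
        json={"prompt": prompt},
        timeout=120,  # seconds; generation on CPU can be slow
    )
    response.raise_for_status()
except requests.RequestException as exc:
    st.error(f"Request failed: {exc}")
else:
    st.markdown("### Output")
    st.write(response.json()["generated_script"])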