CHUNYU0505 committed on
Commit
4d5736b
·
verified ·
1 Parent(s): 255d19f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -61
app.py CHANGED
@@ -5,11 +5,10 @@
5
  import os, glob, requests
6
  from langchain.docstore.document import Document
7
  from langchain.text_splitter import RecursiveCharacterTextSplitter
8
- from langchain.chains import RetrievalQA
9
- from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
10
  from docx import Document as DocxDocument
11
  import gradio as gr
12
- from langchain_community.vectorstores import FAISS
13
 
14
  # -------------------------------
15
  # 2. 環境變數與資料路徑
@@ -17,7 +16,7 @@ from langchain_community.vectorstores import FAISS
17
  TXT_FOLDER = "./out_texts"
18
  DB_PATH = "./faiss_db"
19
  os.makedirs(DB_PATH, exist_ok=True)
20
- os.makedirs(TXT_FOLDER, exist_ok=True) # 避免沒有 txt 檔時錯誤
21
 
22
  HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
23
  if not HF_TOKEN:
@@ -51,53 +50,47 @@ else:
51
  retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
52
 
53
  # -------------------------------
54
- # 4. LLM 設定(Hugging Face Endpoint)
55
  # -------------------------------
56
- llm = HuggingFaceEndpoint(
57
- repo_id="google/flan-t5-large",
58
- task="text2text-generation",
59
- huggingfacehub_api_token=HF_TOKEN,
60
- temperature=0.7,
61
- max_new_tokens=512,
62
- )
63
 
64
- qa_chain = RetrievalQA.from_chain_type(
65
- llm=llm,
66
- retriever=retriever,
67
- return_source_documents=True
68
- )
69
-
70
- # -------------------------------
71
- # 5. 檢查 Hugging Face Token 權限
72
- # -------------------------------
73
- def check_hf_token_permissions():
74
- """確認 Token 是否可呼叫 Inference Endpoint"""
75
- headers = {"Authorization": f"Bearer {HF_TOKEN}"}
 
 
 
 
 
 
 
 
 
 
76
  try:
77
- r = requests.get("https://huggingface.co/api/whoami-v2", headers=headers)
78
  r.raise_for_status()
79
  data = r.json()
80
- if "allow_inference" in data and data["allow_inference"]:
81
- return True
82
- return False
83
  except Exception:
84
- return False
85
-
86
- token_valid = check_hf_token_permissions()
87
- if not token_valid:
88
- print("⚠ 警告:Hugging Face API Token 權限不足,無法呼叫模型。")
89
-
90
 
91
  # -------------------------------
92
- # 6. 生成文章(修正版,支援進度顯示)
93
  # -------------------------------
94
- def generate_article_with_progress(query, segments=5):
95
- if not token_valid:
96
- # Token 權限不足,直接返回訊息
97
- yield "⚠ API Token 權限不足,請檢查 Token 是否允許呼叫 Inference Endpoint。", None
98
- return
99
-
100
- import time
101
  docx_file = "/tmp/generated_article.docx"
102
  doc = DocxDocument()
103
  doc.add_heading(query, level=1)
@@ -106,33 +99,24 @@ def generate_article_with_progress(query, segments=5):
106
  prompt = f"請依據下列主題生成段落:{query}\n\n每段約150-200字。"
107
 
108
  for i in range(int(segments)):
109
- try:
110
- result = qa_chain({"query": prompt})
111
- paragraph = result.get("result", "").strip()
112
- if not paragraph:
113
- paragraph = "(本段生成失敗,請嘗試減少段落或改用較小模型。)"
114
- except Exception as e:
115
- paragraph = f"(本段生成失敗:{e})"
116
-
117
  all_text.append(paragraph)
118
  doc.add_paragraph(paragraph)
119
  prompt = f"請接續上一段生成下一段:\n{paragraph}\n\n下一段:"
120
 
 
121
  yield "\n\n".join(all_text), None
122
- time.sleep(0.1)
123
 
124
  doc.save(docx_file)
125
  rate_info = get_hf_rate_limit()
126
- final_text = f"{rate_info}\n\n" + "\n\n".join(all_text)
127
- yield final_text, docx_file
128
-
129
 
130
  # -------------------------------
131
- # 7. Gradio 介面(修正版)
132
  # -------------------------------
133
  with gr.Blocks() as demo:
134
  gr.Markdown("# 佛教經論 RAG 系統 (HF API)")
135
- gr.Markdown("使用 Hugging Face Endpoint LLM + FAISS RAG,生成文章並提示 API 剩餘額度。")
136
 
137
  query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
138
  segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="段落數")
@@ -140,17 +124,14 @@ with gr.Blocks() as demo:
140
  output_file = gr.File(label="下載 DOCX")
141
 
142
  btn = gr.Button("生成文章")
143
-
144
- # 使用 .click() 搭配 generator
145
  btn.click(
146
- fn=generate_article_with_progress,
147
  inputs=[query_input, segments_input],
148
  outputs=[output_text, output_file]
149
  )
150
 
151
  # -------------------------------
152
- # 8. 啟動 Gradio(HF Space 適用)
153
  # -------------------------------
154
  if __name__ == "__main__":
155
  demo.launch()
156
-
 
5
  import os, glob, requests
6
  from langchain.docstore.document import Document
7
  from langchain.text_splitter import RecursiveCharacterTextSplitter
8
+ from langchain_community.vectorstores import FAISS
9
+ from langchain_huggingface import HuggingFaceEmbeddings
10
  from docx import Document as DocxDocument
11
  import gradio as gr
 
12
 
13
  # -------------------------------
14
  # 2. 環境變數與資料路徑
 
16
  TXT_FOLDER = "./out_texts"
17
  DB_PATH = "./faiss_db"
18
  os.makedirs(DB_PATH, exist_ok=True)
19
+ os.makedirs(TXT_FOLDER, exist_ok=True)
20
 
21
  HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
22
  if not HF_TOKEN:
 
50
  retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
51
 
52
  # -------------------------------
53
+ # 4. 定義 REST API 呼叫函數
54
  # -------------------------------
55
+ INFERENCE_MODEL = "google/flan-t5-large"
56
+ API_URL = f"https://api-inference.huggingface.co/models/{INFERENCE_MODEL}"
57
+ HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}
 
 
 
 
58
 
59
def call_hf_inference(prompt, max_new_tokens=512):
    """Call the Hugging Face Inference API and return the generated text.

    Args:
        prompt: Text prompt sent to the model at ``API_URL``.
        max_new_tokens: Generation length cap forwarded in the payload.

    Returns:
        The model's generated text on success, otherwise a human-readable
        error string beginning with "(生成失敗:". Never raises — callers
        embed the result directly into the article.
    """
    payload = {
        "inputs": prompt,
        "parameters": {"max_new_tokens": max_new_tokens},
    }
    try:
        response = requests.post(API_URL, headers=HEADERS, json=payload, timeout=60)
        response.raise_for_status()
        data = response.json()
        # Success shape: a non-empty list whose first element is a dict with
        # "generated_text". Guard emptiness and element type explicitly so an
        # unexpected response shape is reported as-is instead of surfacing as
        # a misleading IndexError/TypeError caught below.
        if (
            isinstance(data, list)
            and data
            and isinstance(data[0], dict)
            and "generated_text" in data[0]
        ):
            return data[0]["generated_text"]
        if isinstance(data, dict) and "error" in data:
            return f"(生成失敗:{data['error']})"
        return str(data)
    except Exception as e:
        # Deliberate best-effort: network/HTTP/JSON problems become a
        # placeholder paragraph rather than aborting the whole run.
        return f"(生成失敗:{e})"
76
+
77
+ # -------------------------------
78
+ # 5. 查詢 API 剩餘額度
79
+ # -------------------------------
80
def get_hf_rate_limit():
    """Return a one-line, human-readable summary of remaining HF API quota.

    Queries the Hugging Face ``whoami`` endpoint with the configured token
    headers. On any failure returns a fallback message instead of raising,
    since the result is purely informational for the UI.
    """
    try:
        # A timeout is required here: this runs right before the final yield
        # of article generation, and a missing timeout could hang the UI
        # indefinitely if huggingface.co is unreachable.
        r = requests.get(
            "https://huggingface.co/api/whoami", headers=HEADERS, timeout=10
        )
        r.raise_for_status()
        data = r.json()
        # NOTE(review): assumes the response may include a "rate_limit" dict
        # with a "remaining" field — falls back to "未知" when absent.
        remaining = data.get("rate_limit", {}).get("remaining", "未知")
        return f"本小時剩餘 API 次數:約 {remaining}"
    except Exception:
        return "無法取得 API 速率資訊"
 
 
 
 
 
89
 
90
  # -------------------------------
91
+ # 6. 生成文章(即時進度)
92
  # -------------------------------
93
def generate_article_progress(query, segments=5):
    """Generate an article paragraph by paragraph, streaming progress.

    Generator used as a Gradio click callback. Each intermediate yield is
    ``(text_so_far, None)`` so the textbox updates live; the final yield is
    ``(rate_info + full_text, docx_path)`` once the DOCX has been saved.

    Args:
        query: Article topic; used as the DOCX heading and the first prompt.
        segments: Number of paragraphs to generate (coerced with ``int``).
    """
    docx_file = "/tmp/generated_article.docx"
    doc = DocxDocument()
    doc.add_heading(query, level=1)

    # Accumulated paragraphs; initialized here explicitly (the visible code
    # appended to all_text without a visible initializer).
    all_text = []
    prompt = f"請依據下列主題生成段落:{query}\n\n每段約150-200字。"

    for i in range(int(segments)):
        paragraph = call_hf_inference(prompt)
        all_text.append(paragraph)
        doc.add_paragraph(paragraph)
        # Chain the next prompt off the latest paragraph for continuity.
        prompt = f"請接續上一段生成下一段:\n{paragraph}\n\n下一段:"

        # Stream partial text to the UI; no downloadable file until done.
        yield "\n\n".join(all_text), None

    doc.save(docx_file)
    rate_info = get_hf_rate_limit()
    yield f"{rate_info}\n\n" + "\n\n".join(all_text), docx_file
 
 
113
 
114
  # -------------------------------
115
+ # 7. Gradio 介面
116
  # -------------------------------
117
  with gr.Blocks() as demo:
118
  gr.Markdown("# 佛教經論 RAG 系統 (HF API)")
119
+ gr.Markdown("使用 Hugging Face REST API + FAISS RAG,生成文章並提示 API 剩餘額度。")
120
 
121
  query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
122
  segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="段落數")
 
124
  output_file = gr.File(label="下載 DOCX")
125
 
126
  btn = gr.Button("生成文章")
 
 
127
  btn.click(
128
+ generate_article_progress,
129
  inputs=[query_input, segments_input],
130
  outputs=[output_text, output_file]
131
  )
132
 
133
  # -------------------------------
134
+ # 8. 啟動 Gradio(Hugging Face Space 適用)
135
  # -------------------------------
136
  if __name__ == "__main__":
137
  demo.launch()