ChatBotsTA committed on
Commit
8945838
·
verified ·
1 Parent(s): f33a5dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +165 -101
app.py CHANGED
@@ -7,135 +7,199 @@ from huggingface_hub import InferenceClient
7
  import pdfplumber
8
  from PIL import Image
9
  import base64
10
-
11
- # ---------- Configuration ----------
12
- HF_TOKEN = os.environ.get("HF_TOKEN") # required
13
- GROQ_KEY = os.environ.get("GROQ_API_KEY") # optional: if you want to call Groq directly
14
- USE_GROQ_PROVIDER = True # set False to route to default HF provider
15
-
16
- # model IDs (change if you prefer other models)
17
- LLAMA_MODEL = "Groq/Llama-3-Groq-8B-Tool-Use" # Groq Llama model on HF
18
- TTS_MODEL = "espnet/kan-bayashi_ljspeech_vits" # a HF-hosted TTS model example
19
- SDXL_MODEL = "stabilityai/stable-diffusion-xl-base-1.0" # SDXL base model
20
-
21
- # create Inference client (route via HF token by default)
22
- if USE_GROQ_PROVIDER:
23
- client = InferenceClient(provider="groq", api_key=HF_TOKEN)
24
- else:
25
- client = InferenceClient(api_key=HF_TOKEN)
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  # ---------- Helpers ----------
28
- def pdf_to_text(uploaded_file) -> str:
29
  text_chunks = []
30
- with pdfplumber.open(uploaded_file) as pdf:
31
  for page in pdf.pages:
32
  ptext = page.extract_text()
33
  if ptext:
34
  text_chunks.append(ptext)
35
  return "\n\n".join(text_chunks)
36
 
37
- def llama_summarize(text, max_tokens=512):
38
- prompt = [
39
- {"role": "system", "content": "You are a concise summarizer. Produce a clear summary in bullet points."},
40
- {"role": "user", "content": f"Summarize the following document in <= 8 bullet points. Keep it short:\n\n{text}"}
 
 
 
41
  ]
42
- # Use chat completion endpoint style
43
- resp = client.chat.completions.create(model=LLAMA_MODEL, messages=prompt)
44
  try:
45
- summary = resp.choices[0].message["content"]
 
46
  except Exception:
47
- # fallback: try text generation field
48
- summary = resp.choices[0].text if hasattr(resp.choices[0], "text") else str(resp)
49
- return summary
50
-
51
- def llama_chat(chat_history, user_question):
52
- messages = chat_history + [{"role":"user","content":user_question}]
53
- resp = client.chat.completions.create(model=LLAMA_MODEL, messages=messages)
54
- return resp.choices[0].message["content"]
55
-
56
- def tts_synthesize(text) -> bytes:
57
- # InferenceClient offers text->audio utilities. This returns raw audio bytes (wav).
58
- audio_bytes = client.text_to_speech(model=TTS_MODEL, inputs=text)
59
- return audio_bytes
60
-
61
- def generate_image(prompt_text) -> Image.Image:
62
- img_bytes = client.text_to_image(prompt_text, model=SDXL_MODEL)
63
- return Image.open(io.BytesIO(img_bytes))
64
-
65
- def audio_download_button(wav_bytes, filename="summary.wav"):
66
- b64 = base64.b64encode(wav_bytes).decode()
67
- href = f'<a href="data:audio/wav;base64,{b64}" download="{filename}">Download audio (WAV)</a>'
68
- st.markdown(href, unsafe_allow_html=True)
69
-
70
- # ---------- Streamlit UI ----------
71
- st.set_page_config(page_title="PDFGPT (Groq + HF)", layout="wide")
72
- st.title("PDF Summary + Speech + Chat + Diagram (Groq + HF)")
73
-
74
- uploaded = st.file_uploader("Upload PDF", type=["pdf"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  if uploaded:
 
76
  with st.spinner("Extracting text from PDF..."):
77
- text = pdf_to_text(uploaded)
78
- st.subheader("Extracted text (preview)")
79
- st.text_area("Document text", value=text[:1000], height=200)
80
-
81
- if st.button("Create summary (Groq Llama)"):
82
- with st.spinner("Summarizing with Groq Llama..."):
83
- summary = llama_summarize(text)
84
- st.subheader("Summary")
85
- st.write(summary)
86
- st.session_state["summary"] = summary
87
-
88
- if "summary" in st.session_state:
89
- summary = st.session_state["summary"]
90
- if st.button("Synthesize audio from summary (TTS)"):
91
- with st.spinner("Creating audio..."):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  try:
93
- audio = tts_synthesize(summary)
94
- st.audio(audio)
95
- audio_download_button(audio)
 
 
 
96
  except Exception as e:
97
- st.error(f"TTS failed: {e}")
98
 
99
- st.markdown("---")
100
- st.subheader("Chat with your PDF (ask questions about document)")
101
- if "chat_history" not in st.session_state:
102
- # start with system + doc context (shortened)
103
- doc_context = (text[:4000] + "...") if len(text) > 4000 else text
104
- st.session_state["chat_history"] = [
105
- {"role":"system","content":"You are a helpful assistant that answers questions based on the provided document."},
106
- {"role":"user","content": f"Document context:\n{doc_context}"}
107
- ]
108
-
109
- user_q = st.text_input("Ask a question about the PDF")
110
- if st.button("Ask") and user_q:
111
- with st.spinner("Getting answer from Groq Llama..."):
112
- answer = llama_chat(st.session_state["chat_history"], user_q)
113
- st.session_state.setdefault("convo", []).append(("You", user_q))
114
- st.session_state.setdefault("convo", []).append(("Assistant", answer))
115
- # append to history for next calls
116
- st.session_state["chat_history"].append({"role":"user","content":user_q})
117
- st.session_state["chat_history"].append({"role":"assistant","content":answer})
118
- st.write(answer)
119
 
120
  st.markdown("---")
121
- st.subheader("Generate a diagram from your question (SDXL)")
122
  diagram_prompt = st.text_input("Describe the diagram or scene to generate")
123
- if st.button("Generate diagram") and diagram_prompt:
124
- with st.spinner("Generating image (SDXL)..."):
125
  try:
126
  img = generate_image(diagram_prompt)
127
  st.image(img, use_column_width=True)
128
- # allow download
129
  buf = io.BytesIO()
130
  img.save(buf, format="PNG")
131
  st.download_button("Download diagram (PNG)", data=buf.getvalue(), file_name="diagram.png", mime="image/png")
132
  except Exception as e:
133
- st.error(f"Image generation failed: {e}")
134
 
135
- st.sidebar.title("Settings")
136
- st.sidebar.write("Models in use:")
137
  st.sidebar.write(f"LLM: {LLAMA_MODEL}")
138
  st.sidebar.write(f"TTS: {TTS_MODEL}")
139
  st.sidebar.write(f"Image: {SDXL_MODEL}")
140
-
141
- st.sidebar.markdown("**Notes**\n- Set HF_TOKEN in Space secrets or environment before starting.\n- To route directly to Groq with your Groq API key, set `GROQ_API_KEY` and change the client init accordingly.")
 
7
  import pdfplumber
8
  from PIL import Image
9
  import base64
10
+ from typing import Optional
11
+
12
st.set_page_config(page_title="PDF → Summary + TTS + Chat + Diagram", layout="wide")

# ---------- Config (models - change if you prefer others) ----------
LLAMA_MODEL = "Groq/Llama-3-Groq-8B-Tool-Use"  # chat / summarization model
TTS_MODEL = "espnet/kan-bayashi_ljspeech_vits"  # text-to-speech model
SDXL_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"  # text-to-image model

# ---------- Secrets: HF_TOKEN and GROQ_TOKEN ----------
HF_TOKEN = os.environ.get("HF_TOKEN")
GROQ_TOKEN = os.environ.get("GROQ_TOKEN")


# ---------- Init InferenceClient ----------
def _build_client(groq_token, hf_token):
    """Return ``(client, info)`` for the first usable credential.

    ``groq_token`` wins over ``hf_token``. ``client`` is ``None`` when no
    token is set or when constructing the client raises; ``info`` is the
    human-readable status string shown in the sidebar.
    """
    try:
        if groq_token:
            # Prefer the Groq provider whenever a Groq token is present.
            return (
                InferenceClient(provider="groq", api_key=groq_token),
                "Using Groq provider (GROQ_TOKEN)",
            )
        if hf_token:
            return (
                InferenceClient(api_key=hf_token),
                "Using Hugging Face Inference (HF_TOKEN)",
            )
    except Exception as e:
        return None, f"Failed to initialize InferenceClient: {e}"
    return None, "NO TOKEN FOUND"


client: Optional[InferenceClient]
client_info: str
client, client_info = _build_client(GROQ_TOKEN, HF_TOKEN)
39
 
40
  # ---------- Helpers ----------
41
def pdf_to_text_bytes(file_bytes: bytes) -> str:
    """Extract the text of an in-memory PDF.

    Args:
        file_bytes: Raw bytes of the PDF file.

    Returns:
        The text of every page that yielded any, in page order, with a
        blank line between pages. Pages whose extraction result is empty
        or ``None`` are skipped, so the result may be ``""``.
    """
    with pdfplumber.open(io.BytesIO(file_bytes)) as document:
        page_texts = [page.extract_text() for page in document.pages]
    return "\n\n".join(chunk for chunk in page_texts if chunk)
49
 
50
def llama_summarize(text: str) -> str:
    """Summarize ``text`` into a few bullet points with the configured LLM.

    The chat-completions endpoint is tried first. If it raises, the plain
    text-generation endpoint is used as a fallback.

    Args:
        text: Document text to summarize.

    Returns:
        The model's summary.

    Raises:
        RuntimeError: If the client is not initialized, or if both the chat
            call and the text-generation fallback fail. The message carries
            both underlying errors.
    """
    if client is None:
        raise RuntimeError("InferenceClient not initialized (missing HF_TOKEN/GROQ_TOKEN).")

    messages = [
        {"role": "system", "content": "You are a concise summarizer. Provide a short summary in bullet points."},
        {"role": "user", "content": f"Summarize the following document in 6-8 concise bullet points:\n\n{text}"},
    ]

    try:
        resp = client.chat.completions.create(model=LLAMA_MODEL, messages=messages)
        return resp.choices[0].message["content"]
    except Exception as chat_err:
        # Keep the first failure: the name bound by ``except`` is cleared
        # when the handler exits, and the error is reported below if the
        # fallback fails as well.
        chat_error = chat_err

    try:
        # text_generation takes the prompt as its first positional argument;
        # it has no ``inputs`` keyword.
        generated = client.text_generation(
            "Summarize:\n\n" + text,
            model=LLAMA_MODEL,
            max_new_tokens=512,
        )
    except Exception as e:
        raise RuntimeError(
            f"Summarization failed: {e} (chat completion error: {chat_error})"
        ) from e

    # The call normally returns a plain string; tolerate a dict-shaped payload.
    if isinstance(generated, dict) and "generated_text" in generated:
        return generated["generated_text"]
    return str(generated)
73
+
74
def llama_chat(chat_history: list, user_question: str) -> str:
    """Answer ``user_question`` given the prior ``chat_history``.

    ``chat_history`` is not mutated; the new user turn is appended to a copy
    that is sent to the chat-completions endpoint.

    Raises:
        RuntimeError: If the client is not initialized or the request fails.
    """
    if client is None:
        raise RuntimeError("InferenceClient not initialized (missing HF_TOKEN/GROQ_TOKEN).")

    new_turn = {"role": "user", "content": user_question}
    conversation = chat_history + [new_turn]
    try:
        completion = client.chat.completions.create(model=LLAMA_MODEL, messages=conversation)
        first_choice = completion.choices[0]
        return first_choice.message["content"]
    except Exception as e:
        raise RuntimeError(f"Chat completion failed: {e}")
83
+
84
def tts_synthesize(text: str) -> bytes:
    """Convert ``text`` to speech with the configured TTS model.

    Args:
        text: Text to speak.

    Returns:
        The audio as raw bytes, as returned by the inference endpoint.

    Raises:
        RuntimeError: If the client is not initialized or synthesis fails.
    """
    if client is None:
        raise RuntimeError("InferenceClient not initialized (missing HF_TOKEN/GROQ_TOKEN).")
    try:
        # The text is the first positional parameter of text_to_speech;
        # passing it as ``inputs=`` raises TypeError before any request is made.
        return client.text_to_speech(text, model=TTS_MODEL)
    except Exception as e:
        raise RuntimeError(f"TTS failed: {e}") from e
92
+
93
def generate_image(prompt_text: str) -> Image.Image:
    """Generate an image for ``prompt_text`` with the configured SDXL model.

    Args:
        prompt_text: Natural-language description of the image.

    Returns:
        The generated picture as a PIL image.

    Raises:
        RuntimeError: If the client is not initialized or generation fails.
    """
    if client is None:
        raise RuntimeError("InferenceClient not initialized (missing HF_TOKEN/GROQ_TOKEN).")
    try:
        result = client.text_to_image(prompt_text, model=SDXL_MODEL)
        # text_to_image already returns a PIL image; wrapping that in
        # BytesIO raises TypeError. Raw bytes are still decoded in case a
        # client version hands back the encoded payload instead.
        if isinstance(result, (bytes, bytearray)):
            return Image.open(io.BytesIO(result))
        return result
    except Exception as e:
        raise RuntimeError(f"Image generation failed: {e}") from e
101
+
102
def make_download_link_bytes(data: bytes, filename: str, mime: str) -> str:
    """Build an HTML ``<a>`` tag that downloads ``data`` via a data URI.

    Args:
        data: Payload to embed, base64-encoded into the link.
        filename: Suggested download name; also used as the link text.
        mime: MIME type placed in the data URI.

    Returns:
        The anchor element as an HTML string.
    """
    import html  # local import: keeps the block self-contained

    b64 = base64.b64encode(data).decode()
    # The result is rendered as raw HTML, so anything interpolated into
    # attributes or text must be escaped to keep the markup well-formed.
    safe_name = html.escape(filename, quote=True)
    safe_mime = html.escape(mime, quote=True)
    return f'<a href="data:{safe_mime};base64,{b64}" download="{safe_name}">Download {safe_name}</a>'
106
+
107
+ # ---------- UI ----------
108
st.title("PDF → Summary + TTS + Chat + Diagram (Groq/HF)")

st.sidebar.markdown("### Runtime info")
st.sidebar.write(client_info)
st.sidebar.markdown("**Required env vars**: `HF_TOKEN` and/or `GROQ_TOKEN`. Prefer `GROQ_TOKEN` for Groq provider.")

# Every feature below needs the inference client, so end this script run
# early when it could not be created.
if client is None:
    st.error("Inference client not initialized. Set HF_TOKEN or GROQ_TOKEN as environment variables in your Space.")
    st.stop()

uploaded = st.file_uploader("Upload a PDF to analyze", type=["pdf"])
if uploaded:
    # getvalue() returns the full buffer regardless of the current stream
    # position, unlike read().
    file_bytes = uploaded.getvalue()

    # The summary and chat context live in session_state. Drop them when a
    # different document arrives so answers are never based on the previous
    # PDF. Name plus size is a cheap identity check, not a content hash.
    doc_key = (uploaded.name, len(file_bytes))
    if st.session_state.get("doc_key") != doc_key:
        for stale_key in ("summary", "chat_history", "convo_display"):
            if stale_key in st.session_state:
                del st.session_state[stale_key]
        st.session_state["doc_key"] = doc_key

    with st.spinner("Extracting text from PDF..."):
        try:
            text = pdf_to_text_bytes(file_bytes)
        except Exception as e:
            st.error(f"Failed to extract text from PDF: {e}")
            text = ""
    st.subheader("Document preview (first 2000 chars)")
    st.text_area("", value=(text[:2000] + ("..." if len(text) > 2000 else "")), height=220)

    col1, col2 = st.columns(2)

    # Left column: summary and text-to-speech.
    with col1:
        if st.button("Create summary"):
            if not text.strip():
                st.error("Document text empty or extraction failed.")
            else:
                with st.spinner("Summarizing with Llama..."):
                    try:
                        summary = llama_summarize(text)
                        st.session_state["summary"] = summary
                        st.subheader("Summary")
                        st.markdown(summary)
                    except Exception as e:
                        st.error(str(e))

        # The TTS button is offered only once a summary exists for this document.
        if "summary" in st.session_state:
            summary = st.session_state["summary"]
            if st.button("Synthesize summary to audio"):
                with st.spinner("Generating speech..."):
                    try:
                        wav = tts_synthesize(summary)
                        st.audio(wav)
                        st.markdown(make_download_link_bytes(wav, "summary.wav", "audio/wav"), unsafe_allow_html=True)
                    except Exception as e:
                        st.error(str(e))

    # Right column: question answering over the document.
    with col2:
        st.subheader("Chat with the document")
        if "chat_history" not in st.session_state:
            # Seed the conversation with at most the first 4000 characters
            # of the document as context.
            doc_context = text[:4000] if text else ""
            st.session_state["chat_history"] = [
                {"role":"system","content":"You are an assistant that answers questions based only on the provided document context."},
                {"role":"user","content": f"Document context:\n{doc_context}"}
            ]
            st.session_state["convo_display"] = []

        user_q = st.text_input("Ask a question about the PDF")
        if st.button("Ask question") and user_q.strip():
            with st.spinner("Getting answer from Llama..."):
                try:
                    answer = llama_chat(st.session_state["chat_history"], user_q)
                    # Record the turn for display and for the next model call.
                    st.session_state["convo_display"].append(("You", user_q))
                    st.session_state["convo_display"].append(("Assistant", answer))
                    st.session_state["chat_history"].append({"role":"user","content":user_q})
                    st.session_state["chat_history"].append({"role":"assistant","content":answer})
                except Exception as e:
                    st.error(str(e))

        # Show the conversation so far.
        for speaker, textline in st.session_state.get("convo_display", []):
            if speaker == "You":
                st.markdown(f"**You:** {textline}")
            else:
                st.markdown(f"**Assistant:** {textline}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
# NOTE(review): the nesting of this section relative to the `if uploaded:`
# branch could not be recovered from the diff view; it is written at top
# level here. Confirm against the real file.
st.markdown("---")
st.subheader("Generate diagram/image from prompt (SDXL)")
diagram_prompt = st.text_input("Describe the diagram or scene to generate")
generate_clicked = st.button("Generate diagram")
if generate_clicked and diagram_prompt.strip():
    with st.spinner("Generating image..."):
        try:
            img = generate_image(diagram_prompt)
            st.image(img, use_column_width=True)

            # Re-encode the image as PNG in memory for the download button.
            png_buffer = io.BytesIO()
            img.save(png_buffer, format="PNG")
            st.download_button("Download diagram (PNG)", data=png_buffer.getvalue(), file_name="diagram.png", mime="image/png")
        except Exception as e:
            st.error(str(e))

st.sidebar.markdown("---")
st.sidebar.markdown("### Model IDs (change in app.py if you want)")
for role_label, model_id in (("LLM", LLAMA_MODEL), ("TTS", TTS_MODEL), ("Image", SDXL_MODEL)):
    st.sidebar.write(f"{role_label}: {model_id}")