# pdf-summarizer / app.py
# Author: ChatBotsTA
# Last commit: "Update app.py" (a257837, verified)
# app.py
import os
import io
import tempfile
import streamlit as st
from huggingface_hub import InferenceClient
import pdfplumber
from PIL import Image
import base64
# ---------- Configuration ----------
# Credentials come from the environment (e.g. Space secrets).
HF_TOKEN = os.getenv("HF_TOKEN")  # required
# NOTE(review): GROQ_KEY is read here but not used by any code in this file;
# it only matters if the client construction below is changed to use it.
GROQ_KEY = os.getenv("GROQ_API_KEY")  # optional: if you want to call Groq directly
USE_GROQ_PROVIDER = True  # set False to route to default HF provider

# model IDs (change if you prefer other models)
# NOTE(review): LLAMA_MODEL must be served by the chosen provider — confirm it
# is available through the Groq provider mapping.
LLAMA_MODEL = "Groq/Llama-3-Groq-8B-Tool-Use"  # Groq Llama model on HF
TTS_MODEL = "espnet/kan-bayashi_ljspeech_vits"  # a HF-hosted TTS model example
SDXL_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"  # SDXL base model

# Shared inference client, authenticated with the HF token in both modes;
# the flag only decides whether requests are routed to the Groq provider.
client = (
    InferenceClient(provider="groq", api_key=HF_TOKEN)
    if USE_GROQ_PROVIDER
    else InferenceClient(api_key=HF_TOKEN)
)
# ---------- Helpers ----------
def pdf_to_text(uploaded_file) -> str:
    """Return the text of every page of a PDF, pages separated by a blank line.

    Args:
        uploaded_file: Anything ``pdfplumber.open`` accepts; the app passes the
            Streamlit upload object.

    Returns:
        The page texts joined with ``"\\n\\n"``. Pages for which
        ``extract_text()`` yields nothing (``None`` or ``""``) are skipped, so an
        image-only PDF produces an empty string.
    """
    with pdfplumber.open(uploaded_file) as document:
        # Extract while the PDF is still open; filtering happens afterwards.
        page_texts = [page.extract_text() for page in document.pages]
    return "\n\n".join(chunk for chunk in page_texts if chunk)
def llama_summarize(text, max_tokens=512):
    """Summarize *text* as a short bullet list using the chat LLM.

    Args:
        text: Document text; it is sent in full inside the user message.
        max_tokens: Upper bound on the number of tokens generated for the
            summary (forwarded to the chat-completion request).

    Returns:
        The summary text produced by the model. If the response has no
        chat-style message, falls back to a ``text`` field on the choice, and
        finally to the string form of the whole response.
    """
    messages = [
        {"role": "system", "content": "You are a concise summarizer. Produce a clear summary in bullet points."},
        {"role": "user", "content": f"Summarize the following document in <= 8 bullet points. Keep it short:\n\n{text}"},
    ]
    # max_tokens used to be accepted but never passed on, so it had no effect.
    resp = client.chat.completions.create(
        model=LLAMA_MODEL,
        messages=messages,
        max_tokens=max_tokens,
    )
    choice = resp.choices[0]
    try:
        return choice.message["content"]
    except (AttributeError, KeyError, TypeError):
        # fallback: response without a chat-style message; try a text-generation field
        return choice.text if hasattr(choice, "text") else str(resp)
def llama_chat(chat_history, user_question):
    """Ask *user_question* with *chat_history* as prior context and return the reply.

    Args:
        chat_history: List of ``{"role": ..., "content": ...}`` message dicts sent
            before the new question. It is not mutated; a new list is built.
        user_question: The user's new question, appended as a ``"user"`` turn.

    Returns:
        The ``content`` of the first choice's message.
    """
    new_turn = {"role": "user", "content": user_question}
    completion = client.chat.completions.create(
        model=LLAMA_MODEL,
        messages=chat_history + [new_turn],
    )
    first_choice = completion.choices[0]
    return first_choice.message["content"]
def tts_synthesize(text) -> bytes:
    """Convert *text* to speech with ``TTS_MODEL``.

    Args:
        text: Text to speak (the app passes the generated summary).

    Returns:
        Raw audio bytes from ``InferenceClient.text_to_speech``. The container
        format depends on the model; presumably WAV for TTS_MODEL — TODO confirm.
    """
    # The shared ``client`` is routed to the Groq provider when
    # USE_GROQ_PROVIDER is set, and Groq only serves chat completion, so TTS
    # goes through a client on the default provider in that case.
    tts_client = InferenceClient(api_key=HF_TOKEN) if USE_GROQ_PROVIDER else client
    # text_to_speech takes the text as its first positional argument; it has
    # no ``inputs`` keyword (passing one raised TypeError).
    return tts_client.text_to_speech(text, model=TTS_MODEL)
def generate_image(prompt_text) -> Image.Image:
    """Generate an image for *prompt_text* with ``SDXL_MODEL``.

    Args:
        prompt_text: Free-text description of the image to generate.

    Returns:
        The generated image as a PIL ``Image``.
    """
    # The shared ``client`` may be routed to the Groq provider, which does not
    # serve text-to-image, so use a default-provider client in that case.
    image_client = InferenceClient(api_key=HF_TOKEN) if USE_GROQ_PROVIDER else client
    result = image_client.text_to_image(prompt_text, model=SDXL_MODEL)
    # InferenceClient.text_to_image already returns a PIL image; feeding it to
    # BytesIO raised TypeError. Raw bytes are still accepted defensively.
    if isinstance(result, Image.Image):
        return result
    return Image.open(io.BytesIO(result))
def audio_download_button(wav_bytes, filename="summary.wav"):
    """Render an HTML link that downloads *wav_bytes* under *filename*.

    The audio is embedded in the page as a base64 ``data:audio/wav`` URI and
    the link is injected with ``unsafe_allow_html=True``, so no server route is
    involved. The MIME type is always ``audio/wav`` regardless of the bytes.
    """
    encoded_audio = base64.b64encode(wav_bytes).decode()
    data_uri = "data:audio/wav;base64," + encoded_audio
    link_html = f'<a href="{data_uri}" download="{filename}">Download audio (WAV)</a>'
    st.markdown(link_html, unsafe_allow_html=True)
# ---------- Streamlit UI ----------
st.set_page_config(page_title="PDFGPT (Groq + HF)", layout="wide")
st.title("PDF → Summary + Speech + Chat + Diagram (Groq + HF)")


def _start_new_document(doc_key, text):
    """Reset per-document session state and seed the chat context with *text*.

    Session state outlives a single upload, so without this a newly uploaded
    PDF would keep the previous document's summary, conversation and chat
    context.
    """
    for stale_key in ("summary", "convo"):
        if stale_key in st.session_state:
            del st.session_state[stale_key]
    # start with system + doc context (shortened to the first 4000 characters)
    doc_context = (text[:4000] + "...") if len(text) > 4000 else text
    st.session_state["chat_history"] = [
        {"role": "system", "content": "You are a helpful assistant that answers questions based on the provided document."},
        {"role": "user", "content": f"Document context:\n{doc_context}"},
    ]
    st.session_state["doc_text"] = text
    st.session_state["doc_key"] = doc_key


uploaded = st.file_uploader("Upload PDF", type=["pdf"])
if uploaded:
    # Streamlit reruns this script on every interaction; only extract the text
    # (and reset summary/chat state) when a different upload is seen.
    doc_key = (getattr(uploaded, "file_id", None), uploaded.name, getattr(uploaded, "size", None))
    if st.session_state.get("doc_key") != doc_key:
        with st.spinner("Extracting text from PDF..."):
            extracted = pdf_to_text(uploaded)
        _start_new_document(doc_key, extracted)
    text = st.session_state["doc_text"]

    st.subheader("Extracted text (preview)")
    st.text_area("Document text", value=text[:1000], height=200)
    if not text.strip():
        # e.g. image-only / scanned PDFs: pdf_to_text finds no page text.
        st.warning("No extractable text found in this PDF (it may be a scanned image).")

    if st.button("Create summary (Groq Llama)"):
        with st.spinner("Summarizing with Groq Llama..."):
            try:
                st.session_state["summary"] = llama_summarize(text)
            except Exception as e:  # surface API errors in the UI, like TTS/image below
                st.error(f"Summarization failed: {e}")
    if "summary" in st.session_state:
        summary = st.session_state["summary"]
        # Render from session state so the summary stays visible on later
        # reruns (e.g. after clicking the TTS button), not only on creation.
        st.subheader("Summary")
        st.write(summary)
        if st.button("Synthesize audio from summary (TTS)"):
            with st.spinner("Creating audio..."):
                try:
                    audio = tts_synthesize(summary)
                    st.audio(audio)
                    audio_download_button(audio)
                except Exception as e:
                    st.error(f"TTS failed: {e}")

    st.markdown("---")
    st.subheader("Chat with your PDF (ask questions about document)")
    user_q = st.text_input("Ask a question about the PDF")
    if st.button("Ask") and user_q:
        answer = None
        with st.spinner("Getting answer from Groq Llama..."):
            try:
                answer = llama_chat(st.session_state["chat_history"], user_q)
            except Exception as e:
                st.error(f"Chat request failed: {e}")
        if answer is not None:
            st.session_state.setdefault("convo", []).extend([("You", user_q), ("Assistant", answer)])
            # append to history so follow-up questions keep the conversation context
            st.session_state["chat_history"].extend([
                {"role": "user", "content": user_q},
                {"role": "assistant", "content": answer},
            ])
    # Show the whole conversation; previously it was recorded in "convo" but
    # never displayed, so each answer vanished on the next rerun.
    for speaker, message in st.session_state.get("convo", []):
        st.markdown(f"**{speaker}:** {message}")

    st.markdown("---")
    st.subheader("Generate a diagram from your question (SDXL)")
    diagram_prompt = st.text_input("Describe the diagram or scene to generate")
    if st.button("Generate diagram") and diagram_prompt:
        with st.spinner("Generating image (SDXL)..."):
            try:
                img = generate_image(diagram_prompt)
                st.image(img, use_column_width=True)
                # allow download
                buf = io.BytesIO()
                img.save(buf, format="PNG")
                st.download_button("Download diagram (PNG)", data=buf.getvalue(), file_name="diagram.png", mime="image/png")
            except Exception as e:
                st.error(f"Image generation failed: {e}")
# Sidebar: list the configured model IDs and setup notes.
sidebar = st.sidebar
sidebar.title("Settings")
sidebar.write("Models in use:")
for label, model_id in (("LLM", LLAMA_MODEL), ("TTS", TTS_MODEL), ("Image", SDXL_MODEL)):
    sidebar.write(f"{label}: {model_id}")
sidebar.markdown(
    "**Notes**\n"
    "- Set HF_TOKEN in Space secrets or environment before starting.\n"
    "- To route directly to Groq with your Groq API key, set `GROQ_API_KEY` and change the client init accordingly."
)