Spaces:
Sleeping
Sleeping
import streamlit as st | |
from transformers import pipeline | |
import torch | |
import time | |
from typing import List, Dict | |
import functools | |
import signal | |
class TimeoutError(Exception):
    """Raised when a call wrapped by ``timeout`` runs past its time limit.

    NOTE(review): this deliberately-or-not shadows the built-in
    ``TimeoutError`` inside this module; ``except TimeoutError`` here catches
    only this class, not the built-in one.
    """
def timeout(seconds):
    """Decorator factory that limits a call to *seconds* whole seconds via SIGALRM.

    If the limit is reached, the running call is interrupted by
    ``TimeoutError``. However the call ends, the alarm is cancelled and the
    previously installed SIGALRM handler is put back.

    Python only lets the main thread install signal handlers, and SIGALRM is
    not available on every platform (e.g. Windows). In either case the wrapped
    function is run without a time limit instead of raising.
    """
    def decorator(func):
        @functools.wraps(func)  # keep func's __name__/__doc__ on the wrapper
        def wrapper(*args, **kwargs):
            def handler(signum, frame):
                raise TimeoutError(f"Function call timed out after {seconds} seconds")

            if not hasattr(signal, "SIGALRM"):
                # Platform has no SIGALRM: nothing to arm.
                return func(*args, **kwargs)
            try:
                previous_handler = signal.signal(signal.SIGALRM, handler)
            except ValueError:
                # signal.signal() raises ValueError outside the main thread.
                return func(*args, **kwargs)

            signal.alarm(seconds)
            try:
                return func(*args, **kwargs)
            finally:
                # Cancel the alarm first, then restore whoever owned SIGALRM
                # before us (None means it was set from C; fall back to default).
                signal.alarm(0)
                signal.signal(
                    signal.SIGALRM,
                    signal.SIG_DFL if previous_handler is None else previous_handler,
                )
        return wrapper
    return decorator
class SourceVerifier:
    """Holds uploaded source documents and checks statements against them.

    Matching is a loose keyword heuristic: a source counts as a match when any
    whitespace-separated word of the statement occurs, case-insensitively, as
    a substring of the source text.
    """

    def __init__(self):
        # Each entry is {"content": <source text>, "metadata": <caller's dict>}.
        self.sources: List[Dict] = []

    def add_source(self, text: str, metadata: Dict) -> None:
        """Store *text* together with its caller-supplied *metadata*."""
        self.sources.append({"content": text, "metadata": metadata})

    def verify_statement(self, statement: str) -> Dict:
        """Report which stored sources share at least one word with *statement*.

        Returns:
            A dict with ``verified`` (True if any source matched), ``matches``
            (the matching source entries, in insertion order) and
            ``confidence`` (fraction of stored sources that matched, or ``0``
            when no sources are stored).
        """
        # Lower-case the statement's words once and each source text once,
        # rather than re-lowering the entire source for every single word.
        words = [word.lower() for word in statement.split()]
        matches = []
        if words:
            for source in self.sources:
                haystack = source["content"].lower()
                if any(word in haystack for word in words):
                    matches.append(source)
        return {
            "verified": len(matches) > 0,
            "matches": matches,
            "confidence": len(matches) / len(self.sources) if self.sources else 0,
        }
@st.cache_resource(show_spinner=False)
def _build_pipeline():
    """Build the text-generation pipeline once and share it across reruns.

    Streamlit re-executes the whole script on every interaction; without this
    cache the model was reloaded for every chat message. Exceptions propagate
    (and are therefore not cached), so a failed load is retried next run.
    """
    # The former model_kwargs={"low_memory": True} was dropped: "low_memory"
    # is not a from_pretrained option, and unrecognised model kwargs can be
    # forwarded to the model constructor and fail there.
    return pipeline(
        "text-generation",
        model="sshleifer/tiny-gpt2",  # very small GPT-2 checkpoint
        device="cpu",  # Force CPU usage
    )


def load_pipeline():
    """Return the shared text-generation pipeline, or ``None`` if loading fails.

    A load failure is reported to the user with ``st.error``.
    """
    try:
        return _build_pipeline()
    except Exception as e:
        st.error(f"Failed to load model: {str(e)}")
        return None
# NOTE(review): a 10-second timeout appears to have been intended here, but no
# timeout decorator is applied; the TimeoutError branch below only fires if
# *generator* itself raises TimeoutError.
def generate_response(generator, prompt: str, max_new_tokens: int = 50) -> str:
    """Generate a short reply to *prompt* with a text-generation pipeline.

    Args:
        generator: callable following the transformers text-generation
            pipeline convention (returns a list of dicts with a
            ``"generated_text"`` key).
        prompt: the user's message.
        max_new_tokens: cap on the number of newly generated tokens.

    Returns:
        The generated continuation (without the prompt), or a human-readable
        error message. Errors are returned as text rather than raised so the
        chat UI can display them.
    """
    try:
        result = generator(
            prompt,
            # Bound only the new tokens: max_length also counted the prompt,
            # so a long prompt left no room for any reply.
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            # Return just the continuation instead of echoing the prompt back.
            return_full_text=False,
        )
        return result[0]["generated_text"]
    except TimeoutError:
        return "Response generation timed out. Please try again."
    except Exception as e:
        return f"Error generating response: {str(e)}"
def init_page():
    """Configure the page and create per-session chat state if missing.

    ``st.session_state`` persists across reruns, so the greeting message and
    the ``SourceVerifier`` are created only on a session's first run.
    """
    st.set_page_config(
        page_title="Quick Chat Demo",
        # Speech-balloon emoji; the previous value "π¬" was mis-encoded text,
        # not a valid emoji for page_icon.
        page_icon="💬",
        layout="centered",
    )
    st.title("Quick Chat Demo")
    if "messages" not in st.session_state:
        st.session_state.messages = [
            {"role": "assistant", "content": "Hi! I'm a simple chat demo. How can I help?"}
        ]
    if "verifier" not in st.session_state:
        st.session_state.verifier = SourceVerifier()
def handle_file_upload():
    """Show the source uploader and register a newly uploaded file once.

    Streamlit reruns the script on every interaction and ``st.file_uploader``
    keeps returning the same file on each rerun. Previously the same document
    was therefore re-added to the verifier for every chat message. Each upload
    is now fingerprinted and only added the first time it is seen.
    """
    uploaded_file = st.file_uploader("Upload source document", type=["txt", "md", "json"])
    if not uploaded_file:
        return
    if "added_uploads" not in st.session_state:
        st.session_state.added_uploads = set()
    try:
        # getvalue() returns the full contents regardless of stream position.
        raw = uploaded_file.getvalue()
        fingerprint = (uploaded_file.name, len(raw), hash(raw))
        if fingerprint in st.session_state.added_uploads:
            return  # already registered on an earlier rerun
        content = raw.decode()
        st.session_state.verifier.add_source(
            content,
            {"filename": uploaded_file.name, "type": uploaded_file.type},
        )
        st.session_state.added_uploads.add(fingerprint)
        st.success(f"Added source: {uploaded_file.name}")
    except Exception as e:
        st.error(f"Error processing file: {str(e)}")
def _render_message_body(message):
    """Write a chat message's text and, if it was verified, its source report.

    Used for both replayed history and fresh replies so they render the same.
    """
    st.write(message["content"])
    verification = message.get("verification")
    if verification and verification["verified"]:
        with st.expander("Sources"):
            st.json(verification)


def main():
    """Render the chat app for one Streamlit script run."""
    init_page()

    # Load the model behind a spinner.
    with st.spinner("Loading (should take < 5 seconds)..."):
        generator = load_pipeline()
    if generator is None:
        st.error("Failed to initialize chat. Please refresh the page.")
        return

    # Sidebar for document upload
    with st.sidebar:
        st.header("Sources")
        handle_file_upload()

    # Replay the conversation stored in session state, including each
    # assistant reply's "Sources" expander.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            _render_message_body(message)

    # Chat input
    if prompt := st.chat_input("Say something"):
        user_message = {"role": "user", "content": prompt}
        st.session_state.messages.append(user_message)
        with st.chat_message("user"):
            _render_message_body(user_message)

        # Generate the reply (no timeout is applied here) and check it
        # against the uploaded sources.
        with st.chat_message("assistant"):
            with st.spinner("Responding..."):
                response = generate_response(generator, prompt)
                verification = st.session_state.verifier.verify_statement(response)
            assistant_message = {
                "role": "assistant",
                "content": response,
                "verification": verification,
            }
            _render_message_body(assistant_message)
        st.session_state.messages.append(assistant_message)


if __name__ == "__main__":
    main()