import os
from typing import Any, Dict, List

import streamlit as st
import torch
from serpapi import GoogleSearch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

def search_serpapi(query: str, loc: str, api_key: str) -> List[Dict[str, Any]]:
    """Search Google via SerpAPI for the given query and return the organic results."""
    try:
        search = GoogleSearch({
            "q": query,
            "location": loc,
            "api_key": api_key,
        })
        results = search.get_dict()
        return results.get("organic_results", [])
    except Exception as e:
        raise RuntimeError(f"SerpAPI search failed: {e}") from e
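
# Illustrative note (not actual API output): each entry in "organic_results" is a
# dict that typically includes at least "title" and "link" fields, e.g.
#   {"title": "Some headline", "link": "https://example.com/article"}
# convert_to_md_table() below relies only on those two keys.
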
def convert_to_md_table(data: List[Dict[str, Any]]) -> str:
    """Render a list of search results as a Markdown table of titles and links."""
    md_table = "| Title | Link |\n| :--- | :--- |\n"
    for item in data:
        title = item.get("title", "Untitled")
        link = item.get("link", "")
        md_table += f"| {title} | [Link]({link}) |\n"
    return md_table
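
# For the illustrative entry above, this would produce a Markdown table like:
#   | Title | Link |
#   | :--- | :--- |
#   | Some headline | [Link](https://example.com/article) |
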
# Load the fake-news classifier (tokenizer + model) once at startup
tokenizer = AutoTokenizer.from_pretrained("jy46604790/Fake-News-Bert-Detect")
model = AutoModelForSequenceClassification.from_pretrained("jy46604790/Fake-News-Bert-Detect")
model.eval()

def call_classifier(text: str) -> Dict[str, Any]:
    """Classify `text` with the fake-news model and return its label and confidence."""
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    label_id = torch.argmax(probabilities, dim=-1).item()
    score = probabilities[0][label_id].item()
    # Return the label in the same "LABEL_n" form used to initialize session state below
    return {"label": f"LABEL_{label_id}", "score": score}
# Initialize session state
if "history" not in st.session_state:
    st.session_state.history = []
if "user_input" not in st.session_state:
    st.session_state.user_input = ""
if "score" not in st.session_state:
    st.session_state.score = {"label": "LABEL_0", "score": 0.0}

# Streamlit app layout
st.title("Chatbot News Search")

# User input
st.session_state.user_input = st.text_input("What news do you want to search?", st.session_state.user_input)

# Minimum classifier confidence required before running a news search
threshold = 0.7

# Main logic
if st.session_state.user_input:
    st.session_state.history.append(f"User: {st.session_state.user_input}")
    if st.session_state.score["score"] > threshold:
        query = st.session_state.user_input
        SERPAPI_API_KEY = os.getenv("SERPAPI_API_KEY")  # Use environment variable for the SerpAPI key
        if not SERPAPI_API_KEY:
            st.error("SerpAPI API key not found. Please set the SERPAPI_API_KEY environment variable.")
        else:
            news_results = search_serpapi(query, "New York", SERPAPI_API_KEY)
            formatted_news = convert_to_md_table(news_results)
            st.session_state.history.append(f"Chatbot: Here are the latest news results:\n{formatted_news}")
            user_continue = st.radio("Are you okay with this?", ("Y", "E"), key="continue_initial")
            if user_continue == "E":
                st.session_state.history.append("User exited the conversation.")
            else:
                new_score = call_classifier(query)
                attempt = 0
                while new_score["score"] < st.session_state.score["score"]:
                    st.session_state.history.append("Run the SerpAPI again.")
                    new_score = call_classifier(query)
                    st.session_state.history.append(f"New score: {new_score['score']}")
                    attempt += 1
                    # Unique key per iteration so Streamlit does not raise a duplicate-widget error
                    user_continue = st.radio("Are you okay with this?", ("Y", "E"), key=f"continue_retry_{attempt}")
                    if user_continue == "E":
                        st.session_state.history.append("User exited the conversation.")
                        break
                st.session_state.score = new_score
    else:
        st.session_state.history.append(f"Chatbot: Current score: {st.session_state.score['score']}")
        st.session_state.user_input = st.text_input("Please provide more information to refine the news search:", st.session_state.user_input)
        st.session_state.score = call_classifier(st.session_state.user_input)
        st.session_state.history.append(f"New score: {st.session_state.score['score']}")

# Display chat history
for message in st.session_state.history:
    st.write(message)

st.write("If you want to finish the conversation, please enter 'exit'.")