# kzz1027's picture
# 2cbca4c verified
import os
import streamlit as st
from typing import List, Dict, Any
from serpapi import GoogleSearch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
def search_serpapi(query: str, loc: str, api_key: str) -> List[Dict[str, Any]]:
    """
    Search using SerpAPI for the given query and return the results.

    Args:
        query: Search terms, sent as the ``q`` parameter.
        loc: Value sent as the ``location`` parameter.
        api_key: SerpAPI key, sent as the ``api_key`` parameter.

    Returns:
        The ``organic_results`` list from the response dictionary, or an
        empty list when that key is absent.

    Raises:
        Exception: If building or executing the search fails; the original
            error is chained as ``__cause__``.
    """
    try:
        search = GoogleSearch({
            "q": query,
            "location": loc,
            "api_key": api_key,
        })
        results = search.get_dict()
        return results.get("organic_results", [])
    except Exception as e:
        # Chain the cause so the underlying traceback is not lost.
        raise Exception(f"An error occurred: {e}") from e
def convert_to_md_table(data):
    """
    Render search results as a two-column Markdown table.

    Args:
        data: Iterable of mappings; the ``title`` and ``link`` entries of
            each mapping are used. A missing ``title`` yields an empty cell
            and a missing or empty ``link`` yields an empty link cell.

    Returns:
        A Markdown string made of a header row, an alignment row and one row
        per item, each terminated by a newline.
    """
    header = "| Title | Link |\n| :--- | :--- |\n"
    rows = []
    for item in data:
        # A raw "|" in the title would be read as a column separator and
        # break the table layout, so escape it.
        title = str(item.get("title", "")).replace("|", "\\|")
        link = item.get("link", "")
        link_cell = f"[Link]({link})" if link else ""
        rows.append(f"| {title} | {link_cell} |\n")
    return header + "".join(rows)
# Load model directly
@st.cache_resource
def _load_classifier():
    """Load the fake-news tokenizer and model once and reuse them.

    Streamlit re-executes the whole script on every interaction; caching the
    loaded objects avoids re-instantiating the model on each rerun.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    model_name = "jy46604790/Fake-News-Bert-Detect"
    return (
        AutoTokenizer.from_pretrained(model_name),
        AutoModelForSequenceClassification.from_pretrained(model_name),
    )


tokenizer, model = _load_classifier()
def call_classifier(text: str) -> Dict[str, Any]:
    """
    Classify ``text`` with the module-level ``tokenizer`` and ``model``.

    Args:
        text: The text to classify.

    Returns:
        A dict with ``"label"`` (the integer index of the highest-probability
        class) and ``"score"`` (the softmax probability of that class).
    """
    # truncation=True asks the tokenizer to cut over-long input down to its
    # configured maximum length instead of passing an oversized sequence to
    # the model. NOTE(review): relies on the tokenizer config declaring a
    # model_max_length -- confirm for this checkpoint.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**inputs)
    logits = outputs.logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    label = torch.argmax(probabilities, dim=1).item()
    score = probabilities[0][label].item()
    return {"label": label, "score": score}
# Initialize session state
# Each entry is written only when the key is not already present, so values
# stored by a previous run of the script are left untouched.
_session_defaults = {
    'history': [],
    'user_input': "",
    'score': {'label': 'LABEL_0', 'score': 0.0},
}
for _key, _default in _session_defaults.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
# Streamlit app layout
st.title("Chatbot News Search")

# User input: the text box is pre-filled with the stored query, and whatever
# it currently holds is written back into session state.
previous_query = st.session_state.user_input
st.session_state.user_input = st.text_input(
    "What news do you want to search?",
    previous_query,
)

# Threshold: the stored classifier score is compared against this value.
threshold = 0.7
# Main logic
# Streamlit re-executes this script on every interaction, so each pass below
# handles a single turn; no in-script loop is needed to wait for the user.
if st.session_state.user_input:
    st.session_state.history.append(f"User: {st.session_state.user_input}")
    if st.session_state.score['score'] > threshold:
        query = st.session_state.user_input
        SERPAPI_API_KEY = os.getenv("SERPAPI_API_KEY")  # Use environment variable for SerpAPI key
        if not SERPAPI_API_KEY:
            # Without a key the search cannot run; report it and stop here.
            st.error("SerpAPI API key not found. Please set the SERPAPI_API_KEY environment variable.")
        else:
            try:
                news_results = search_serpapi(query, "New York", SERPAPI_API_KEY)
            except Exception as exc:
                # UI boundary: show the failure instead of a raw traceback.
                st.error(f"News search failed: {exc}")
            else:
                formatted_news = convert_to_md_table(news_results)
                st.session_state.history.append(f"Chatbot: Here are the latest news results:\n{formatted_news}")
                # Explicit keys keep the two identically-labelled radio
                # widgets from colliding.
                user_continue = st.radio("Are you okay with this?", ('Y', 'E'), key="confirm_results")
                if user_continue == 'E':
                    st.session_state.history.append("User exited the conversation.")
                else:
                    # Keep the full result dict: st.session_state.score is
                    # indexed with ['score'] elsewhere, so it must never be
                    # replaced by a bare float.
                    new_score = call_classifier(query)
                    # A single check replaces the former `while`: the query
                    # does not change within one run, so looping on it could
                    # never make progress.
                    if new_score['score'] < st.session_state.score['score']:
                        st.session_state.history.append("Run the SerpAPI again.")
                        st.session_state.history.append(f"New score: {new_score['score']}")
                        user_continue = st.radio("Are you okay with this?", ('Y', 'E'), key="confirm_rescore")
                        if user_continue == 'E':
                            st.session_state.history.append("User exited the conversation.")
                    st.session_state.score = new_score
    else:
        # Score not above the threshold: ask for a refined query and re-score it.
        st.session_state.history.append(f"Chatbot: Current score: {st.session_state.score['score']}")
        st.session_state.user_input = st.text_input("Please provide more information to refine the news search:", st.session_state.user_input)
        st.session_state.score = call_classifier(st.session_state.user_input)
        st.session_state.history.append(f"New score: {st.session_state.score['score']}")
# Display chat history
# Render every stored message in order, then show the exit hint once.
for message in st.session_state.history:
    st.write(message)
# NOTE(review): nothing in the visible code reacts to the input 'exit' --
# confirm whether that handling lives elsewhere.
st.write("If you want to finish the conversation, please enter 'exit'.")