import os
from typing import Any, Dict, List

import streamlit as st
import torch
from serpapi import GoogleSearch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
def search_serpapi(query: str, loc: str, api_key: str) -> List[Dict[str, Any]]:
    """
    Search Google via SerpAPI for the given query and return the organic results.
    """
    try:
        search = GoogleSearch({
            "q": query,
            "location": loc,
            "api_key": api_key,
        })
        results = search.get_dict()
        return results.get("organic_results", [])
    except Exception as e:
        raise RuntimeError(f"SerpAPI search failed: {e}") from e
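# For reference, each entry in "organic_results" is a dict; an illustrative (not
# exhaustive) SerpAPI item looks roughly like:
#     {"position": 1, "title": "Sample headline",
#      "link": "https://example.com/article", "snippet": "First lines of the article..."}
# Only "title" and "link" are used below.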
def convert_to_md_table(data: List[Dict[str, Any]]) -> str:
    """
    Convert a list of search results into a Markdown table of titles and links.
    """
    md_table = "| Title | Link |\n| :--- | :--- |\n"
    for item in data:
        title = item.get("title", "")
        link = item.get("link", "")
        md_table += f"| {title} | [Link]({link}) |\n"
    return md_table
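# Illustrative example of the Markdown produced for a single result
# (title and link here are placeholders, not real output):
#     | Title | Link |
#     | :--- | :--- |
#     | Sample headline | [Link](https://example.com/article) |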
# Load the fake-news detection model and tokenizer directly from the Hugging Face Hub
tokenizer = AutoTokenizer.from_pretrained("jy46604790/Fake-News-Bert-Detect")
model = AutoModelForSequenceClassification.from_pretrained("jy46604790/Fake-News-Bert-Detect")
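# Streamlit reruns this whole script on every interaction, so the model is reloaded each
# time. A minimal sketch of how it could be cached instead, assuming the installed
# Streamlit version provides st.cache_resource:
#     @st.cache_resource
#     def load_model():
#         tok = AutoTokenizer.from_pretrained("jy46604790/Fake-News-Bert-Detect")
#         mdl = AutoModelForSequenceClassification.from_pretrained("jy46604790/Fake-News-Bert-Detect")
#         return tok, mdl
#     tokenizer, model = load_model()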
def call_classifier(text: str) -> Dict[str, Any]:
    """
    Classify the text and return the predicted label index with its confidence score.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    label = torch.argmax(probabilities, dim=-1).item()
    score = probabilities[0][label].item()
    return {"label": label, "score": score}
# Initialize session state
if 'history' not in st.session_state:
    st.session_state.history = []
if 'user_input' not in st.session_state:
    st.session_state.user_input = ""
if 'score' not in st.session_state:
    st.session_state.score = {'label': 0, 'score': 0.0}
# Streamlit app layout
st.title("Chatbot News Search")

# User input
st.session_state.user_input = st.text_input("What news do you want to search?", st.session_state.user_input)

# Confidence threshold for triggering a news search
threshold = 0.7
# Main logic
if st.session_state.user_input:
    st.session_state.history.append(f"User: {st.session_state.user_input}")
    if st.session_state.score['score'] > threshold:
        query = st.session_state.user_input
        SERPAPI_API_KEY = os.getenv("SERPAPI_API_KEY")  # Use environment variable for the SerpAPI key
        if not SERPAPI_API_KEY:
            st.error("SerpAPI API key not found. Please set the SERPAPI_API_KEY environment variable.")
        else:
            news_results = search_serpapi(query, "New York", SERPAPI_API_KEY)
            formatted_news = convert_to_md_table(news_results)
            st.session_state.history.append(f"Chatbot: Here are the latest news results:\n{formatted_news}")
            user_continue = st.radio("Are you okay with this?", ('Y', 'E'), key="confirm_results")
            if user_continue == 'E':
                st.session_state.history.append("User exited the conversation.")
            else:
                new_score = call_classifier(query)
                # Re-score a bounded number of times; the classifier is deterministic for
                # the same query, so an unbounded loop here would never terminate.
                max_retries = 3
                attempt = 0
                while new_score['score'] < st.session_state.score['score'] and attempt < max_retries:
                    attempt += 1
                    st.session_state.history.append("Run the SerpAPI again.")
                    new_score = call_classifier(query)
                    st.session_state.history.append(f"New score: {new_score['score']}")
                    user_continue = st.radio("Are you okay with this?", ('Y', 'E'), key=f"confirm_rescore_{attempt}")
                    if user_continue == 'E':
                        st.session_state.history.append("User exited the conversation.")
                        break
                st.session_state.score = new_score
    else:
        st.session_state.history.append(f"Chatbot: Current score: {st.session_state.score['score']}")
        st.session_state.user_input = st.text_input("Please provide more information to refine the news search:", st.session_state.user_input)
        st.session_state.score = call_classifier(st.session_state.user_input)
        st.session_state.history.append(f"New score: {st.session_state.score['score']}")
# Display chat history
for message in st.session_state.history:
    st.write(message)

st.write("If you want to finish the conversation, please enter 'exit'.")