Spaces:
Running
Running
File size: 5,921 Bytes
a94de35 0aad6fd e115381 790fcbd fa73971 d6dc06a fe63ca1 968f358 0c5711c 968f358 ef8f9db 88bcf19 0c5711c fe63ca1 492a988 e115381 0aad6fd 85c189a 0aad6fd e115381 0aad6fd e115381 0aad6fd a94de35 58a8659 e115381 e972582 0c5711c b8ce407 0c5711c 218b4bd 72ae0b1 0c5711c a94de35 e115381 e972582 51db61d 0aad6fd e115381 8d6c903 4dcc069 e48e44f 4dcc069 950465b e115381 f1ebe90 e972582 62c86e1 577cbf8 a94de35 45f3194 a94de35 e115381 950465b 289d0f6 ce7bb20 e115381 f1ebe90 580cfe1 a2373c5 fa73971 08dec69 fa73971 f1ebe90 d366837 e04907a 4a56dea ce7bb20 a94de35 b962028 ce7bb20 4a56dea a94de35 4a56dea ce7bb20 4a56dea a94de35 b962028 ce7bb20 4a56dea ce7bb20 e04907a ce7bb20 3366ac8 e972582 ce7bb20 a94de35 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import streamlit as st
import hashlib
import os
import requests
import time
from langsmith import traceable
import random
from transformers import pipeline
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from pydantic import BaseModel
from typing import List, Optional
from tqdm import tqdm
import re
import os
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Page config must be the first Streamlit call in the script.
st.set_page_config(page_title="TeapotAI Chat", page_icon=":robot_face:", layout="wide")
tokenizer = None  # assigned below inside the spinner block
model = None      # assigned below inside the spinner block
model_name = "teapotai/teapotllm"
# Load tokenizer and model pinned to an exact revision for reproducibility;
# the spinner covers the (potentially slow) first-run download.
# NOTE(review): this runs at import time on every Streamlit rerun — consider
# @st.cache_resource so the model is loaded once per process; confirm.
with st.spinner('Loading Model...'):
    tokenizer = AutoTokenizer.from_pretrained(model_name, revision="699ab39cbf586674806354e92fbd6179f9a95f4a")
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name,revision="699ab39cbf586674806354e92fbd6179f9a95f4a")
def log_time(func):
    """Decorator that prints the wall-clock execution time of ``func``.

    The wrapped function's arguments and return value pass through unchanged.
    """
    # Local import keeps the file's top-level import block untouched.
    from functools import wraps

    @wraps(func)  # preserve __name__/__doc__ so stacked decorators and logs report the real function
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        print(f"{func.__name__} executed in {end_time - start_time:.4f} seconds")
        return result
    return wrapper
# Brave Search subscription token, read from the environment so the secret is
# never committed. NOTE(review): may be None if unset — brave_search will then
# get non-200 responses and return []; confirm the deployment sets it.
API_KEY = os.environ.get("brave_api_key")
@log_time
def brave_search(query, count=3):
    """Query the Brave web-search API.

    Args:
        query: Free-text search string.
        count: Maximum number of results to request (default 3).

    Returns:
        A list of ``(title, description, url)`` tuples; an empty list on any
        HTTP error status or network failure.
    """
    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {"Accept": "application/json", "X-Subscription-Token": API_KEY}
    params = {"q": query, "count": count}
    try:
        # Timeout so an unreachable endpoint cannot hang the Streamlit worker;
        # without it requests.get blocks indefinitely.
        response = requests.get(url, headers=headers, params=params, timeout=10)
    except requests.RequestException as e:
        # Network-level failure (DNS, timeout, connection reset): degrade to no results.
        print(f"Error: Brave search request failed: {e}")
        return []
    if response.status_code == 200:
        results = response.json().get("web", {}).get("results", [])
        return [(res["title"], res["description"], res["url"]) for res in results]
    print(f"Error: {response.status_code}, {response.text}")
    return []
@traceable
@log_time
def query_teapot(prompt, context, user_input):
    """Generate a response from the TeapotLLM seq2seq model.

    The model input is the system prompt, retrieved context, and the user's
    question joined by newlines.

    Args:
        prompt: System prompt describing the assistant's persona.
        context: Retrieval-augmented grounding text.
        user_input: The user's question.

    Returns:
        The decoded model output with special tokens stripped.
    """
    input_text = prompt + "\n" + context + "\n" + user_input
    inputs = tokenizer(input_text, return_tensors="pt")
    # Cap generation so a single request cannot run unbounded; timing is
    # already reported by @log_time (the previous hand-rolled token/sec
    # bookkeeping was dead code and has been removed).
    output = model.generate(**inputs, max_new_tokens=512)
    return tokenizer.decode(output[0], skip_special_tokens=True)
@log_time
def handle_chat(user_prompt, user_input):
    """Process one chat turn: echo the user message, retrieve web context, query the model, render the reply."""
    # Echo the user's message into the transcript and persist it across reruns.
    with st.chat_message("user"):
        st.markdown(user_input)
    st.session_state.messages.append({"role": "user", "content": user_input})
    # Retrieve grounding documents for the question (RAG).
    results = brave_search(user_input)
    # Brave descriptions embed <strong> highlight tags; strip them before
    # feeding the text to the model.
    documents = [desc.replace('<strong>','').replace('</strong>','') for _, desc, _ in results]
    st.sidebar.write("---")
    st.sidebar.write("## RAG Documents")
    for (title, description, url) in results:
        # Display Results
        st.sidebar.write(f"## {title}")
        st.sidebar.write(f"{description.replace('<strong>','').replace('</strong>','')}")
        st.sidebar.write(f"[Source]({url})")
        st.sidebar.write("---")
    context = "\n".join(documents)
    prompt = """You are Teapot, an open-source AI assistant optimized for low-end devices, providing short, accurate responses without hallucinating while excelling at information extraction and text summarization. If a user asks who you are reply "I am Teapot"."""
    # NOTE(review): context and user_prompt are concatenated with no separator;
    # a "\n" between them was likely intended — confirm before changing.
    response = query_teapot(prompt, context+user_prompt, user_input)
    # Render and persist the assistant's reply.
    with st.chat_message("assistant"):
        st.markdown(response)
    st.session_state.messages.append({"role": "assistant", "content": response})
    return response
def main():
    """Streamlit entry point: sidebar prompt box, suggestion buttons, chat history, and chat loop."""
    st.sidebar.header("Retrieval Augmented Generation")
    user_prompt = st.sidebar.text_area("Enter prompt, leave empty for search")

    # Suggestion pools: about-Teapot questions, general QA, and extraction tasks.
    list1 = ["Tell me about teapotllm", "What is Teapot AI?","What devices can Teapot run on?","Who are you?"]
    list2 = ["Who invented quantum mechanics?", "Who are the authors of attention is all you need", "Tell me about popular places to travel in France","Summarize the book irobot", "Explain artificial intelligence","what are the key ingredients of bouillabaisse"]
    # Fixed user-facing typo: "Extarct" -> "Extract".
    list3 = ["Extract the year Google was founded", "Extract the last name of the father of artificial intelligence", "Output the capital of New York","Extract the city where the louvre is located","Find the chemical symbol for gold","Extract the name of the woman who was the first computer programmer"]

    # One random suggestion per pool, shown as three clickable buttons.
    choice1 = random.choice(list1)
    choice2 = random.choice(list2)
    choice3 = random.choice(list3)
    s1, s2, s3 = st.columns([1, 1, 1])
    user_suggested_input = None
    with s1:
        if st.button(choice1, use_container_width=True):
            user_suggested_input = choice1
    with s2:
        if st.button(choice2, use_container_width=True):
            user_suggested_input = choice2
    with s3:
        if st.button(choice3, use_container_width=True):
            user_suggested_input = choice3

    # Seed and replay the chat transcript stored in session state.
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Hi, I am Teapot AI, how can I help you?"}]
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    user_input = st.chat_input("Ask me anything")
    # Bug fix: the handler previously ran only `if user_input:`, so clicking a
    # suggestion button (which leaves the chat box empty on that rerun) did
    # nothing. Trigger on either source; a clicked suggestion keeps priority,
    # matching the original `user_suggested_input or user_input` precedence.
    query = user_suggested_input or user_input
    if query:
        with st.spinner('Generating Response...'):
            handle_chat(user_prompt, query)

if __name__ == "__main__":
    main()
|