File size: 5,921 Bytes
a94de35
 
0aad6fd
 
e115381
790fcbd
fa73971
d6dc06a
fe63ca1
968f358
 
 
 
 
 
 
 
 
 
 
0c5711c
968f358
ef8f9db
88bcf19
 
 
0c5711c
fe63ca1
 
492a988
e115381
 
 
 
 
 
 
 
 
0aad6fd
85c189a
0aad6fd
e115381
0aad6fd
 
 
e115381
0aad6fd
 
 
 
 
 
 
 
 
 
a94de35
58a8659
e115381
e972582
0c5711c
 
 
 
 
 
 
b8ce407
0c5711c
 
 
 
 
 
 
 
 
 
218b4bd
72ae0b1
0c5711c
a94de35
e115381
e972582
51db61d
 
 
 
0aad6fd
e115381
 
8d6c903
 
4dcc069
 
 
e48e44f
4dcc069
 
950465b
e115381
f1ebe90
e972582
62c86e1
 
 
 
 
577cbf8
a94de35
 
45f3194
a94de35
e115381
950465b
289d0f6
ce7bb20
e115381
f1ebe90
580cfe1
a2373c5
fa73971
08dec69
 
fa73971
f1ebe90
 
 
d366837
e04907a
4a56dea
ce7bb20
 
a94de35
b962028
ce7bb20
4a56dea
a94de35
4a56dea
ce7bb20
4a56dea
a94de35
b962028
ce7bb20
 
 
 
4a56dea
ce7bb20
 
 
e04907a
ce7bb20
 
 
3366ac8
e972582
ce7bb20
a94de35
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import functools
import hashlib
import os
import random
import re
import time
from typing import List, Optional

import numpy as np
import requests
import streamlit as st
import torch
from langsmith import traceable
from pydantic import BaseModel
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
st.set_page_config(page_title="TeapotAI Chat", page_icon=":robot_face:", layout="wide")

# Model identity pinned to a specific Hub revision so deployments are
# reproducible even if the upstream repo changes. (Previously the revision
# hash was duplicated on two lines and tokenizer/model were pointlessly
# initialized to None before being immediately reassigned.)
MODEL_NAME = "teapotai/teapotllm"
MODEL_REVISION = "699ab39cbf586674806354e92fbd6179f9a95f4a"

model_name = MODEL_NAME  # kept for backward compatibility with any external reference
with st.spinner('Loading Model...'):
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, revision=MODEL_REVISION)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, revision=MODEL_REVISION)

def log_time(func):
    """Decorator that prints the wall-clock duration of each call to *func*.

    Fix: applies ``functools.wraps`` so the wrapped function keeps its
    ``__name__``/docstring — without it, stacked decorators (e.g. the
    ``@traceable`` used elsewhere in this file) and the log line itself
    would report ``wrapper`` instead of the real function name.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        print(f"{func.__name__} executed in {end_time - start_time:.4f} seconds")
        return result
    return wrapper


API_KEY = os.environ.get("brave_api_key")

@log_time
def brave_search(query, count=3):
    """Query the Brave web-search API and return result triples.

    Parameters
    ----------
    query : str
        Free-text search query.
    count : int, optional
        Maximum number of results to request (default 3).

    Returns
    -------
    list[tuple[str, str, str]]
        ``(title, description, url)`` per result; an empty list on any
        non-200 response. Descriptions may contain ``<strong>`` markup.
    """
    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {"Accept": "application/json", "X-Subscription-Token": API_KEY}
    params = {"q": query, "count": count}

    # Fix: a timeout so a stalled upstream request cannot hang the Streamlit
    # worker forever (requests has no default timeout). A timeout now raises
    # requests.Timeout instead of blocking indefinitely.
    response = requests.get(url, headers=headers, params=params, timeout=10)

    if response.status_code == 200:
        results = response.json().get("web", {}).get("results", [])
        print(results)
        return [(res["title"], res["description"], res["url"]) for res in results]
    else:
        print(f"Error: {response.status_code}, {response.text}")
        return []

@traceable 
@log_time
def query_teapot(prompt, context, user_input):
    """Run one seq2seq generation pass over the concatenated prompt.

    Parameters
    ----------
    prompt : str
        System-style instruction text.
    context : str
        Retrieved RAG context (newline-joined documents).
    user_input : str
        The user's message.

    Returns
    -------
    str
        The decoded model output (special tokens stripped).
    """
    input_text = prompt + "\n" + context + "\n" + user_input

    start_time = time.time()

    inputs = tokenizer(input_text, return_tensors="pt")
    input_length = inputs["input_ids"].shape[1]

    output = model.generate(**inputs, max_new_tokens=512)
    output_text = tokenizer.decode(output[0], skip_special_tokens=True)

    # Fix: throughput stats were computed then silently discarded — log them.
    # generate() returns input + newly generated tokens in one tensor.
    output_length = output.shape[1] - input_length
    elapsed_time = time.time() - start_time
    tokens_per_second = output.shape[1] / elapsed_time if elapsed_time > 0 else float("inf")
    print(f"Generated {output_length} new tokens ({tokens_per_second:.2f} tokens/sec total)")

    return output_text



@log_time
def handle_chat(user_prompt, user_input):
    """Handle one chat turn: echo the user message, retrieve web context,
    generate the assistant reply, and render everything.

    Returns the assistant's reply text.
    """
    # Echo the user's message into the transcript and session history.
    with st.chat_message("user"):
        st.markdown(user_input)
        st.session_state.messages.append({"role": "user", "content": user_input})

    search_hits = brave_search(user_input)

    # Build plain-text snippets for the model: strip Brave's highlight markup.
    documents = []
    for _, snippet, _ in search_hits:
        documents.append(snippet.replace('<strong>', '').replace('</strong>', ''))

    # Show the retrieved documents in the sidebar for transparency.
    st.sidebar.write("---")
    st.sidebar.write("## RAG Documents")
    for title, snippet, link in search_hits:
        st.sidebar.write(f"## {title}")
        st.sidebar.write(f"{snippet.replace('<strong>','').replace('</strong>','')}")
        st.sidebar.write(f"[Source]({link})")
        st.sidebar.write("---")

    context = "\n".join(documents)
    prompt = """You are Teapot, an open-source AI assistant optimized for low-end devices, providing short, accurate responses without hallucinating while excelling at information extraction and text summarization. If a user asks who you are reply "I am Teapot"."""
    response = query_teapot(prompt, context + user_prompt, user_input)

    # Render the reply and persist it to the session transcript.
    with st.chat_message("assistant"):
        st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})

    return response


def main():
    """Streamlit entry point: sidebar RAG prompt, suggestion buttons, and the
    chat loop.

    Bug fix: clicking a suggestion button previously did nothing unless the
    user ALSO submitted text in the chat box — the handler was gated on
    ``if user_input:`` alone. A click or a typed message now each trigger a
    response. Also removes the unused ``random_selection`` list.
    """
    st.sidebar.header("Retrieval Augmented Generation")
    user_prompt = st.sidebar.text_area("Enter prompt, leave empty for search")

    # Suggestion pools: about-Teapot, general QA, and extraction-style tasks.
    list1 = ["Tell me about teapotllm", "What is Teapot AI?","What devices can Teapot run on?","Who are you?"]
    list2 = ["Who invented quantum mechanics?", "Who are the authors of attention is all you need", "Tell me about popular places to travel in France","Summarize the book irobot", "Explain artificial intelligence","what are the key ingredients of bouillabaisse"]
    list3 = ["Extract the year Google was founded", "Extract the last name of the father of artificial intelligence", "Output the capital of New York","Extarct the city where the louvre is located","Find the chemical symbol for gold","Extract the name of the woman who was the first computer programmer"]

    # One random suggestion per pool, shown as three side-by-side buttons.
    choice1 = random.choice(list1)
    choice2 = random.choice(list2)
    choice3 = random.choice(list3)

    s1, s2, s3 = st.columns([1, 1, 1])

    user_suggested_input = None

    with s1:
        if st.button(choice1, use_container_width=True):
            user_suggested_input = choice1

    with s2:
        if st.button(choice2, use_container_width=True):
            user_suggested_input = choice2

    with s3:
        if st.button(choice3, use_container_width=True):
            user_suggested_input = choice3

    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Hi, I am Teapot AI, how can I help you?"}]

    # Replay the transcript so history survives Streamlit reruns.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    user_input = st.chat_input("Ask me anything")

    # A button click takes priority over typed input; either one triggers a turn.
    chosen_input = user_suggested_input or user_input
    if chosen_input:
        with st.spinner('Generating Response...'):
            handle_chat(user_prompt, chosen_input)
           

# Script entry point (run with: streamlit run <this_file>).
if __name__ == "__main__":
    main()