File size: 2,987 Bytes
3d0f7c4
a0c938b
 
3d0f7c4
abfd83b
26a9c66
abfd83b
c51711e
5d4ec37
b88e8f9
c51711e
3d0f7c4
 
 
 
26a9c66
 
a0c938b
3d0f7c4
26a9c66
 
 
 
3d0f7c4
a0c938b
 
3d0f7c4
a0c938b
 
3d0f7c4
a0c938b
3d0f7c4
 
26a9c66
 
a0c938b
 
c51711e
ddcad02
3d0f7c4
 
 
 
c51711e
3d0f7c4
 
 
 
26a9c66
 
 
 
 
 
 
a0c938b
26a9c66
a0c938b
 
26a9c66
 
 
a0c938b
 
 
 
c51711e
3d0f7c4
 
 
 
 
 
ddcad02
c51711e
a0c938b
 
3d0f7c4
26a9c66
ddcad02
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from dotenv import load_dotenv
from functools import lru_cache

# Read variables from a local .env file (if present) into the process environment,
# then pick up the Hugging Face access token; it stays None when the variable is unset.
load_dotenv()
HF_TOKEN = os.environ.get("HF_TOKEN")

# Page header rendered above the chat transcript.
st.title("I am Your GrowBuddy 🌱")
st.write("Let me help you start gardening. Let's grow together!")

# Load the tokenizer and model once; st.cache_resource reuses the same objects on
# every Streamlit rerun instead of reloading the weights.
# NOTE: no quantization is applied here - the weights are loaded in full float32.
@st.cache_resource
def load_model(
    tokenizer_id="TheSheBots/UrbanGardening",
    model_id="google/gemma-2b-it",
):
    """Load and return ``(tokenizer, model)`` from the Hugging Face Hub.

    Args:
        tokenizer_id: Hub repository to load the (fast) tokenizer from.
        model_id: Hub repository to load the causal-LM weights from.

    Returns:
        ``(tokenizer, model)`` on success, or ``(None, None)`` if loading fails.
        On failure, the error has already been shown to the user via ``st.error``.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_id, use_auth_token=HF_TOKEN, use_fast=True
        )
        # NOTE(review): the tokenizer and the weights come from different repos by
        # default; confirm the tokenizer's vocabulary matches the model's.
        # float32 keeps CPU inference simple; this is NOT an 8-bit/quantized load.
        model = AutoModelForCausalLM.from_pretrained(
            model_id, use_auth_token=HF_TOKEN, torch_dtype=torch.float32
        )
    except Exception as e:
        # Report in the UI and signal failure with (None, None).
        st.error(f"Failed to load model: {e}")
        return None, None
    return tokenizer, model

# Fetch the (cached) tokenizer and model; halt the page if loading failed.
tokenizer, model = load_model()

# Compare against None explicitly: Hugging Face tokenizers define __len__, so plain
# truthiness would test the vocabulary size rather than whether loading succeeded.
if tokenizer is None or model is None:
    st.stop()

# Run inference on the CPU (the weights were requested as float32 when loaded).
device = torch.device("cpu")
model = model.to(device)

# Seed this browser session's chat history with a greeting the first time only.
if "messages" not in st.session_state:
    greeting = {"role": "assistant", "content": "Hello there! How can I help you with gardening today?"}
    st.session_state.messages = [greeting]

# Streamlit redraws the page on every rerun, so replay each stored turn.
for entry in st.session_state.messages:
    speaker = entry["role"]
    with st.chat_message(speaker):
        st.write(entry["content"])

# Cache generated answers for repeated questions.
# functools.lru_cache cannot do this here: Streamlit re-executes the whole script on
# every interaction, which re-creates an lru_cache-decorated function with an empty
# cache, so it never produced a hit. st.cache_data persists across reruns.
# Parameters starting with "_" are excluded from the cache key, so Streamlit does not
# try to hash the tokenizer/model objects; they are singletons from load_model.
@st.cache_data(max_entries=100, show_spinner=False)
def _generate_cached(prompt, _tokenizer, _model):
    """Generate a continuation of *prompt* and return only the newly generated text."""
    inputs = _tokenizer(
        prompt, return_tensors="pt", truncation=True, padding=True, max_length=512
    ).to(device)
    prompt_len = inputs["input_ids"].shape[-1]
    outputs = _model.generate(
        input_ids=inputs["input_ids"],
        # Pass the mask explicitly instead of letting generate() infer it.
        attention_mask=inputs.get("attention_mask"),
        max_new_tokens=50,
        temperature=0.7,
        do_sample=True,
    )
    # The output row holds the prompt tokens followed by the new tokens; decode only
    # the new ones so the reply does not repeat the user's question.
    return _tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)


def cached_generate_response(prompt, tokenizer, model):
    """Return a (possibly cached) model reply for *prompt*.

    Args:
        prompt: The user's question text.
        tokenizer: Tokenizer returned by ``load_model``.
        model: Causal LM returned by ``load_model``.

    Returns:
        The decoded generated text, without the echoed prompt. Identical prompts
        reuse a previously sampled answer (up to 100 cached entries).
    """
    return _generate_cached(prompt, tokenizer, model)

# Produce the assistant's reply for the UI, never raising to the caller.
def generate_response(prompt):
    """Return the model's reply to *prompt*, or a fixed apology if generation fails.

    Any exception raised while generating is shown to the user via ``st.error``.
    """
    try:
        answer = cached_generate_response(prompt, tokenizer, model)
    except Exception as err:
        st.error(f"Error during text generation: {err}")
        answer = "Sorry, I couldn't process your request."
    return answer

# Chat box for gardening questions; the value is falsy when nothing was submitted this run.
user_input = st.chat_input("Type your gardening question here:")

if user_input:
    # Echo the question into the transcript right away.
    with st.chat_message("user"):
        st.write(user_input)

    # Show a spinner inside the assistant bubble while the reply is generated.
    with st.chat_message("assistant"), st.spinner("Generating your answer..."):
        response = generate_response(user_input)
        st.write(response)

    # Store both turns so the history loop renders them on the next rerun.
    st.session_state.messages.extend(
        [
            {"role": "user", "content": user_input},
            {"role": "assistant", "content": response},
        ]
    )