amiguel's picture
Update app.py
b35fe75 verified
raw
history blame contribute delete
No virus
3.94 kB
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import time
import json
from datetime import datetime
class ChatApp:
    """Streamlit chat front-end for the inspection-scope classification model.

    The conversation is kept in ``st.session_state.messages`` as a list of
    ``{"role": ..., "content": ...}`` dicts. The first entry is the system
    prompt; it is sent to the model but never rendered in the chat view.
    """

    # Written as escapes so the glyphs survive encoding round-trips:
    # U+1F50D (magnifying glass) and U+258C (left half block cursor).
    _PAGE_ICON = "\U0001F50D"
    _STREAM_CURSOR = "\u258c"

    # Speaker labels that clean_response() removes from the start of a line.
    _ROLE_PREFIXES = ("system", "user", "assistant")

    def __init__(self) -> None:
        """Configure the page, seed the session state and load the model."""
        st.set_page_config(
            page_title="Inspection Methods Engineer Assistant",
            page_icon=self._PAGE_ICON,
            layout="wide",
        )
        self.initialize_session_state()
        self.model_handler = self.load_model()

    def initialize_session_state(self) -> None:
        """Create the message list with the system prompt if it does not exist."""
        if "messages" not in st.session_state:
            st.session_state.messages = [
                {"role": "system", "content": "You are an experienced inspection methods engineer. Your task is to classify the following scope: "}
            ]

    @staticmethod
    @st.cache_resource
    def load_model():
        """Load the fine-tuned model and its tokenizer.

        Wrapped in ``st.cache_resource`` so Streamlit can reuse the loaded
        objects instead of loading them again.

        Returns:
            A ``ModelHandler`` wrapping the loaded model and tokenizer.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        st.info(f"Using device: {device}")
        model_name = "amiguel/classItem-FT-llama-3-1-8b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            # 8-bit loading is only requested when a CUDA device is present.
            load_in_8bit=(device == "cuda"),
        )
        return ModelHandler(model, tokenizer)

    def display_message(self, role: str, content: str) -> None:
        """Render *content* as markdown inside a chat bubble for *role*."""
        with st.chat_message(role):
            st.markdown(content)

    def get_user_input(self):
        """Return the text submitted through the chat input box, if any."""
        return st.chat_input("Type your message here...")

    def stream_response(self, response: str) -> str:
        """Reveal *response* chunk by chunk and return the text that was shown.

        The text is split on single spaces only, so newlines and paragraph
        breaks inside the response are preserved, and the returned string is
        exactly *response* (no trailing space is added).
        """
        placeholder = st.empty()
        chunks = response.split(" ")
        last_index = len(chunks) - 1
        shown = ""
        for index, chunk in enumerate(chunks):
            # Re-insert the separator consumed by split(), except after the end.
            shown += chunk if index == last_index else chunk + " "
            placeholder.markdown(shown + self._STREAM_CURSOR)
            time.sleep(0.01)
        # Final render without the cursor glyph.
        placeholder.markdown(shown)
        return shown

    def save_chat_history(self) -> str:
        """Dump the session messages to a timestamped JSON file.

        The file is written to the current working directory of the process.

        Returns:
            The name of the file that was written.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"chat_history_{timestamp}.json"
        # Explicit UTF-8 keeps non-ASCII chat text readable on every platform.
        with open(filename, "w", encoding="utf-8") as f:
            json.dump(st.session_state.messages, f, indent=2, ensure_ascii=False)
        return filename

    @staticmethod
    def _strip_prompt_echo(prompt: str, response: str) -> str:
        """Remove *prompt* from the front of *response* if it is repeated there.

        Best effort: the response is returned unchanged when it does not start
        with the exact prompt text.
        """
        if prompt and response.startswith(prompt):
            return response[len(prompt):]
        return response

    def run(self) -> None:
        """Render the chat history, handle one user turn and draw the sidebar."""
        st.title("Inspection Methods Engineer Assistant")

        # The system prompt stays in the history but is not shown to the user.
        for message in st.session_state.messages:
            if message["role"] != "system":
                self.display_message(message["role"], message["content"])

        user_input = self.get_user_input()
        if user_input:
            self.display_message("user", user_input)
            st.session_state.messages.append({"role": "user", "content": user_input})

            # The model receives every message body joined by blank lines.
            prompt = "\n\n".join(
                msg["content"] for msg in st.session_state.messages
            ).strip()

            with st.spinner("Analyzing and classifying scope..."):
                response = self.model_handler.generate_response(prompt)

            # Drop a repeated prompt (if the handler returned one) before
            # removing speaker labels, so only the model's own text is shown.
            reply = self._strip_prompt_echo(prompt, response)
            reply = self.clean_response(reply).strip()

            with st.chat_message("assistant"):
                full_response = self.stream_response(reply)

            st.session_state.messages.append({"role": "assistant", "content": full_response})

        st.sidebar.title("Chat Options")
        if st.sidebar.button("Save Chat History"):
            filename = self.save_chat_history()
            st.sidebar.success(f"Chat history saved to {filename}")

    def clean_response(self, response: str) -> str:
        """Strip leading ``system:``/``user:``/``assistant:`` labels from each line.

        Only a recognised speaker label (case-insensitive) is removed; any
        other colon in a line, such as ``Class: Valve`` or ``10:30``, is kept.
        """
        cleaned_lines = []
        for line in response.split("\n"):
            label, separator, remainder = line.partition(":")
            if separator and label.strip().lower() in self._ROLE_PREFIXES:
                line = remainder.strip()
            cleaned_lines.append(line)
        return "\n".join(cleaned_lines)
class ModelHandler:
    """Pair a causal language model with its tokenizer for text generation."""

    def __init__(self, model, tokenizer):
        """Store the model and tokenizer used by ``generate_response``."""
        self.model = model
        self.tokenizer = tokenizer

    def generate_response(self, conversation, max_new_tokens=100):
        """Generate a continuation of *conversation* and return only the new text.

        Args:
            conversation: Prompt text passed to the tokenizer.
            max_new_tokens: Upper bound on generated tokens, forwarded to
                ``model.generate`` (defaults to 100).

        Returns:
            The decoded newly generated tokens, with special tokens skipped.
            The prompt itself is not included.
        """
        inputs = self.tokenizer(conversation, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # For a causal LM the returned sequence is the prompt tokens followed
        # by the generated ones; slice the prompt off so it is not echoed.
        prompt_length = inputs["input_ids"].shape[-1]
        new_tokens = outputs[0][prompt_length:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)
if __name__ == "__main__":
    # Build the application and render the page in one step.
    ChatApp().run()