File size: 3,938 Bytes
ab222f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b35fe75
ab222f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import json
import re
import time
from datetime import datetime

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

class ChatApp:
    """Streamlit chat page that sends the conversation to a language model.

    The chat transcript is kept in ``st.session_state.messages`` as a list of
    ``{"role": ..., "content": ...}`` dicts. The first entry is a system
    prompt that is sent to the model but never displayed.
    """

    # Speaker label at the start of a line, e.g. "user: " or "Assistant:".
    _ROLE_PREFIX = re.compile(r"^\s*(?:system|user|assistant)\s*:\s*", re.IGNORECASE)
    # One word together with the whitespace around it, so that joining all
    # chunks reproduces the original text (newlines included).
    _STREAM_CHUNK = re.compile(r"\s*\S+\s*")

    def __init__(self):
        """Configure the page, seed the session state and load the model."""
        st.set_page_config(
            page_title="Inspection Methods Engineer Assistant",
            # Magnifying-glass emoji, written as an escape so the icon cannot
            # be corrupted by a wrong file encoding.
            page_icon="\U0001F50D",
            layout="wide",
        )
        self.initialize_session_state()
        self.model_handler = self.load_model()

    def initialize_session_state(self):
        """Create the message list with the system prompt if it is missing."""
        if "messages" not in st.session_state:
            st.session_state.messages = [
                {"role": "system", "content": "You are an experienced inspection methods engineer. Your task is to classify the following scope: "}
            ]

    @staticmethod
    @st.cache_resource
    def load_model():
        """Load the tokenizer and model and wrap them in a ``ModelHandler``.

        The function is decorated with ``st.cache_resource``. 8-bit loading is
        requested only when CUDA is available.

        Returns:
            A ``ModelHandler`` holding the loaded model and tokenizer.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        st.info(f"Using device: {device}")
        model_name = "amiguel/classItem-FT-llama-3-1-8b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            load_in_8bit=(device == "cuda"),
        )
        return ModelHandler(model, tokenizer)

    def display_message(self, role, content):
        """Render ``content`` as markdown inside a chat bubble for ``role``."""
        with st.chat_message(role):
            st.markdown(content)

    def get_user_input(self):
        """Return the text submitted in the chat input box, if any."""
        return st.chat_input("Type your message here...")

    def stream_response(self, response, delay=0.01):
        """Reveal ``response`` one word at a time and return the shown text.

        The text is split into word chunks that keep their surrounding
        whitespace, so line breaks and markdown structure in the response are
        preserved rather than being collapsed into single spaces.

        Args:
            response: Text to display.
            delay: Pause in seconds between two chunks.

        Returns:
            The text that was displayed; empty if ``response`` contains no
            non-whitespace characters.
        """
        placeholder = st.empty()
        shown = ""
        for chunk in self._STREAM_CHUNK.findall(response):
            shown += chunk
            # "\u258c" (left half block) acts as a typing cursor.
            placeholder.markdown(shown + "\u258c")
            time.sleep(delay)
        placeholder.markdown(shown)
        return shown

    def save_chat_history(self):
        """Write the current messages to a timestamped JSON file.

        Returns:
            The name of the file that was written, relative to the current
            working directory.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"chat_history_{timestamp}.json"
        with open(filename, "w", encoding="utf-8") as f:
            json.dump(st.session_state.messages, f, indent=2)
        return filename

    def run(self):
        """Draw the page and handle one round of user input."""
        st.title("Inspection Methods Engineer Assistant")

        # Replay the transcript; the system prompt stays hidden.
        for message in st.session_state.messages:
            if message["role"] != "system":
                self.display_message(message["role"], message["content"])

        user_input = self.get_user_input()
        if user_input:
            self.display_message("user", user_input)
            st.session_state.messages.append({"role": "user", "content": user_input})

            # The model receives every message so far as one plain-text prompt.
            conversation = "\n\n".join(msg["content"] for msg in st.session_state.messages)

            with st.spinner("Analyzing and classifying scope..."):
                response = self.model_handler.generate_response(conversation.strip())

            clean_response = self.clean_response(response)
            with st.chat_message("assistant"):
                full_response = self.stream_response(clean_response)
            st.session_state.messages.append({"role": "assistant", "content": full_response})

        st.sidebar.title("Chat Options")
        if st.sidebar.button("Save Chat History"):
            filename = self.save_chat_history()
            st.sidebar.success(f"Chat history saved to {filename}")

    def clean_response(self, response):
        """Strip leading ``system:``/``user:``/``assistant:`` labels per line.

        Only a role label at the very start of a line is removed. Any other
        colon (for example in "Class: Piping" or "10:30") is left alone, so
        the content of the answer is not truncated.

        Args:
            response: Raw text returned by the model.

        Returns:
            The text with role labels removed; lines without a label are
            returned unchanged.
        """
        cleaned = []
        for line in response.split('\n'):
            label = self._ROLE_PREFIX.match(line)
            cleaned.append(line[label.end():].strip() if label else line)
        return '\n'.join(cleaned)

class ModelHandler:
    """Pair a text-generation model with its tokenizer."""

    def __init__(self, model, tokenizer):
        """Store the model and tokenizer used by ``generate_response``."""
        self.model = model
        self.tokenizer = tokenizer

    def generate_response(self, conversation, max_new_tokens=100):
        """Generate a continuation for ``conversation`` and return it as text.

        Args:
            conversation: Prompt text passed to the tokenizer.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded newly generated tokens only, with special tokens
            skipped. The prompt is not repeated in the result.
        """
        inputs = self.tokenizer(conversation, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # The generated sequence starts with the prompt tokens; drop them so
        # the caller does not get its own prompt echoed back as the answer.
        prompt_length = inputs["input_ids"].shape[-1]
        new_tokens = outputs[0][prompt_length:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)

if __name__ == "__main__":
    # Build the app and render the page when this file is executed as a script.
    ChatApp().run()