junaidbaber committed on
Commit fccfdf4 · verified · 1 Parent(s): ef628bc

Update app.py

Files changed (1)
  1. app.py +131 -57
app.py CHANGED
@@ -1,63 +1,137 @@
-
-from huggingface_hub import login
-import os
-token = os.environ.get("hf")
-login(token)
-
-import streamlit as st
-from transformers import pipeline
-import torch
-
-# Model ID
-MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
-
-@st.cache_resource
-def load_pipeline():
-    try:
-        st.write("Loading the instruct pipeline...")
-        instruct_pipeline = pipeline(
-            "text-generation",
-            model=MODEL_ID,
-            model_kwargs={"torch_dtype": torch.bfloat16},
-            device_map="auto",
-        )
-        st.write("Pipeline successfully loaded.")
-        return instruct_pipeline
-    except Exception as e:
-        st.error(f"Error loading pipeline: {e}")
-        return None
-
-# Load the pipeline
-instruct_pipeline = load_pipeline()
-
-# Streamlit UI
-st.title("Instruction Chatbot")
-st.write("Chat with the instruction-tuned model!")
-
-if instruct_pipeline is None:
-    st.error("Pipeline failed to load. Please check the configuration.")
-else:
-    # Message-based interaction
-    system_message = st.text_area("System Message", value="You are a helpful assistant.", height=100)
-    user_input = st.text_input("User:", placeholder="Ask a question or provide an instruction...")
-
-    if st.button("Send"):
-        if user_input.strip():
-            try:
-                messages = [
-                    {"role": "system", "content": system_message},
-                    {"role": "user", "content": user_input},
-                ]
-                # Generate response
-                outputs = instruct_pipeline(
-                    messages,
-                    max_new_tokens=150,  # Limit response length
-                )
-                # Display the generated response
-                response = outputs[0]["generated_text"]
-                st.write(f"**Assistant:** {response}")
-            except Exception as e:
-                st.error(f"Error generating response: {e}")
-        else:
-            st.warning("Please enter a valid message.")
-
+import streamlit as st
+from huggingface_hub import login
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+import torch
+import os
+
+def initialize_model():
+    """Initialize the model and tokenizer"""
+    # Log in to Hugging Face
+    token = os.environ.get("hf")
+    login(token)
+
+    # Define the model ID and device
+    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # Configure INT8 quantization
+    bnb_config = BitsAndBytesConfig(
+        load_in_8bit=True,
+        llm_int8_enable_fp32_cpu_offload=True
+    )
+
+    # Load tokenizer and model
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        quantization_config=bnb_config,
+        device_map="auto"
+    )
+
+    # Ensure padding token is defined
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    return model, tokenizer, device
+
+def format_conversation(conversation_history):
+    """Format the conversation history into a single string."""
+    formatted = ""
+    for turn in conversation_history:
+        formatted += f"User: {turn['user']}\nAssistant: {turn['assistant']}\n"
+    return formatted.strip()
+
+def generate_response(model, tokenizer, device, prompt, conversation_history):
+    """Generate model response"""
+    # Format the entire conversation context
+    context = format_conversation(conversation_history[:-1])
+    if context:
+        full_prompt = f"{context}\nUser: {prompt}"
+    else:
+        full_prompt = f"User: {prompt}"
+
+    # Tokenize input
+    inputs = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True).to(device)
+
+    # Calculate max new tokens
+    input_length = inputs["input_ids"].shape[1]
+    max_model_length = 2048
+    max_new_tokens = min(200, max_model_length - input_length)
+
+    # Generate response
+    outputs = model.generate(
+        inputs["input_ids"],
+        attention_mask=inputs["attention_mask"],
+        max_new_tokens=max_new_tokens,
+        temperature=0.7,
+        top_p=0.9,
+        pad_token_id=tokenizer.pad_token_id,
+        do_sample=True,
+        min_length=20,
+        no_repeat_ngram_size=3
+    )
+
+    # Decode response
+    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    response_parts = response.split("User: ")
+    model_response = response_parts[-1].split("Assistant: ")[-1].strip()
+
+    return model_response
+
+def main():
+    st.set_page_config(page_title="LLM Chat Interface", page_icon="🤖")
+    st.title("Chat with LLM 🤖")
+
+    # Initialize session state for chat history
+    if "chat_history" not in st.session_state:
+        st.session_state.chat_history = []
+
+    # Initialize model (only once)
+    if "model" not in st.session_state:
+        with st.spinner("Loading the model... This might take a minute..."):
+            model, tokenizer, device = initialize_model()
+            st.session_state.model = model
+            st.session_state.tokenizer = tokenizer
+            st.session_state.device = device
+
+    # Display chat messages
+    for message in st.session_state.chat_history:
+        with st.chat_message("user"):
+            st.write(message["user"])
+        with st.chat_message("assistant"):
+            st.write(message["assistant"])
+
+    # Chat input
+    if prompt := st.chat_input("What would you like to know?"):
+        # Display user message
+        with st.chat_message("user"):
+            st.write(prompt)
+
+        # Generate and display assistant response
+        with st.chat_message("assistant"):
+            with st.spinner("Thinking..."):
+                current_turn = {"user": prompt, "assistant": ""}
+                st.session_state.chat_history.append(current_turn)
+
+                response = generate_response(
+                    st.session_state.model,
+                    st.session_state.tokenizer,
+                    st.session_state.device,
+                    prompt,
+                    st.session_state.chat_history
+                )
+
+                st.write(response)
+                st.session_state.chat_history[-1]["assistant"] = response
+
+    # Manage context window
+    if len(st.session_state.chat_history) > 5:
+        st.session_state.chat_history = st.session_state.chat_history[-5:]
+
+    # Add a clear chat button
+    if st.sidebar.button("Clear Chat"):
+        st.session_state.chat_history = []
+        st.rerun()
+
+if __name__ == "__main__":
+    main()
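
The refactored helpers (initialize_model, format_conversation, generate_response) are plain functions with no Streamlit state of their own, so they can be exercised outside the Space UI. Below is a minimal smoke-test sketch, assuming app.py is importable from the working directory, the hf environment variable holds a valid Hugging Face token, and a GPU with enough memory for the 8-bit 8B model is available; the file name smoke_test.py is hypothetical and not part of this commit.

# smoke_test.py -- hypothetical local check, not included in this commit.
# Assumes: app.py in the same directory, the "hf" env var set to a valid token,
# and bitsandbytes plus a CUDA GPU available for 8-bit loading.
from app import initialize_model, generate_response

model, tokenizer, device = initialize_model()

# Mirror main(): the current turn is appended with an empty assistant slot
# before generate_response() is called, since it drops the last turn as context.
history = [{"user": "What does INT8 quantization change?", "assistant": ""}]
reply = generate_response(model, tokenizer, device, history[-1]["user"], history)
print(reply)

In the Space itself the app is launched with streamlit run app.py, with the token supplied through the environment variable named hf (for example as a Space secret).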