Exched committed on
Commit 484e942 · verified
1 Parent(s): 31ddf35

Create app.py

Files changed (1)
  1. app.py +155 -0
app.py ADDED
@@ -0,0 +1,155 @@
+ import os
+ import torch
+ from accelerate import Accelerator
+ from PIL import Image
+ import random
+ import requests
+ import streamlit as st
+ from transformers import BlipProcessor, BlipForConditionalGeneration
+ from langchain_huggingface import HuggingFaceEndpoint
+ from langchain_core.prompts import PromptTemplate
+ from langchain_core.output_parsers import StrOutputParser
+
+ # Define the model IDs
+ llm_model_id = "mistralai/Mistral-7B-Instruct-v0.3"
+ blip_model_id = "Salesforce/blip-image-captioning-large"
+
+ # Initialize BLIP processor and model
+ processor = BlipProcessor.from_pretrained(blip_model_id)
+ model = BlipForConditionalGeneration.from_pretrained(blip_model_id)
+
+ # Initialize the accelerator
+ accelerator = Accelerator()
+
+ def get_llm_hf_inference(model_id=llm_model_id, max_new_tokens=128, temperature=0.1):
+     # Build a HuggingFaceEndpoint client for the chat model; return None if it fails.
+     try:
+         llm = HuggingFaceEndpoint(
+             repo_id=model_id,
+             max_new_tokens=max_new_tokens,
+             temperature=temperature,
+             token=os.getenv("HF_TOKEN")
+         )
+     except Exception as e:
+         st.error(f"Error loading model: {e}")
+         llm = None
+     return llm
+
+ def generate_caption(image, min_len=30, max_len=100):
+     # Caption the uploaded image with BLIP; fall back to an error message on failure.
+     try:
+         inputs = processor(image, return_tensors="pt")
+         out = model.generate(**inputs, min_length=min_len, max_length=max_len)
+         caption = processor.decode(out[0], skip_special_tokens=True)
+         return caption
+     except Exception as e:
+         st.error(f"Error generating caption: {e}")
+         return 'Unable to generate caption.'
+
+ # Configure the Streamlit app
+ st.set_page_config(page_title="HuggingFace ChatBot", page_icon="🤗")
+ st.title("Personal HuggingFace ChatBot")
+ st.markdown(f"*This is a simple chatbot using the HuggingFace transformers library with {llm_model_id}.*")
+
+ # Initialize session state
+ if "avatars" not in st.session_state:
+     st.session_state.avatars = {'user': None, 'assistant': None}
+
+ if 'user_text' not in st.session_state:
+     st.session_state.user_text = None
+
+ if "max_response_length" not in st.session_state:
+     st.session_state.max_response_length = 256
+
+ if "system_message" not in st.session_state:
+     st.session_state.system_message = "friendly AI conversing with a human user"
+
+ if "starter_message" not in st.session_state:
+     st.session_state.starter_message = "Hello, there! How can I help you today?"
+
+ if "uploaded_image_path" not in st.session_state:
+     st.session_state.uploaded_image_path = None
+
+ # Sidebar for settings
+ with st.sidebar:
+     st.header("System Settings")
+     st.session_state.system_message = st.text_area(
+         "System Message", value="You are a friendly AI conversing with a human user."
+     )
+     st.session_state.starter_message = st.text_area(
+         'First AI Message', value="Hello, there! How can I help you today?"
+     )
+     st.session_state.max_response_length = st.number_input(
+         "Max Response Length", value=128
+     )
+     st.markdown("*Select Avatars:*")
+     col1, col2 = st.columns(2)
+     with col1:
+         st.session_state.avatars['assistant'] = st.selectbox(
+             "AI Avatar", options=["🤗", "💬", "🤖"], index=0
+         )
+     with col2:
+         st.session_state.avatars['user'] = st.selectbox(
+             "User Avatar", options=["👤", "👱‍♂️", "👨🏾", "👩", "👧🏾"], index=0
+         )
+     reset_history = st.button("Reset Chat History")
+
+ # Initialize or reset chat history
+ if "chat_history" not in st.session_state or reset_history:
+     st.session_state.chat_history = [{"role": "assistant", "content": st.session_state.starter_message}]
+
+ def get_response(system_message, chat_history, user_text, max_new_tokens=256):
+     # Query the LLM with the system message, the running conversation, and the new user text.
+     hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.1)
+     if hf is None:
+         return "Error with model inference.", chat_history
+
+     prompt = PromptTemplate.from_template(
+         "[INST] {system_message}\nCurrent Conversation:\n{chat_history}\n\nUser: {user_text}.\n [/INST]\nAI:"
+     )
+     chat = prompt | hf | StrOutputParser()
+     response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=chat_history))
+     response = response.split("AI:")[-1]
+
+     chat_history.append({'role': 'user', 'content': user_text})
+     chat_history.append({'role': 'assistant', 'content': response})
+     return response, chat_history
+
+ # Chat interface
+ chat_interface = st.container()
+ with chat_interface:
+     output_container = st.container()
+
+ # Image upload and captioning
+ uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
+ if uploaded_image and st.session_state.uploaded_image_path is None:
+     # Save the uploaded image to a session-local directory
+     with st.spinner("Processing image..."):
+         image = Image.open(uploaded_image).convert("RGB")
+
+         # Create a directory for session images if it does not exist
+         if not os.path.exists("session_images"):
+             os.makedirs("session_images")
+
+         # Save image to local session directory and remember its path
+         image_path = os.path.join("session_images", uploaded_image.name)
+         image.save(image_path)
+         st.session_state.uploaded_image_path = image_path
+
+         # Generate a caption and add it to the chat history
+         caption = generate_caption(image)
+         st.session_state.chat_history.append({'role': 'user', 'content': f'![uploaded image]({image_path})'})
+         st.session_state.chat_history.append({'role': 'assistant', 'content': caption})
+
+ st.session_state.user_text = st.chat_input(placeholder="Enter your text here.")
+
+ if st.session_state.user_text:
+     with st.chat_message("user", avatar=st.session_state.avatars['user']):
+         st.markdown(st.session_state.user_text)
+     with st.chat_message("assistant", avatar=st.session_state.avatars['assistant']):
+         with st.spinner("Thinking..."):
+             response, st.session_state.chat_history = get_response(
+                 system_message=st.session_state.system_message,
+                 chat_history=st.session_state.chat_history,
+                 user_text=st.session_state.user_text,
+                 max_new_tokens=st.session_state.max_response_length
+             )
+         st.markdown(response)
+
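Note: to try this app locally (assuming the usual dependencies such as streamlit, torch, transformers, accelerate, pillow, langchain-core, and langchain-huggingface are installed), it would be started with "streamlit run app.py", with the HF_TOKEN environment variable set so that get_llm_hf_inference can authenticate the HuggingFaceEndpoint client.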