dipta007 committed on
Commit
a602459
1 Parent(s): 55aeedd
Files changed (5) hide show
  1. .gitignore +178 -0
  2. app.py +99 -0
  3. requirements.txt +2 -0
  4. upload.py +27 -0
  5. utils.py +22 -0
.gitignore ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Created by https://www.toptal.com/developers/gitignore/api/python
2
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python
3
+
4
+ ### Python ###
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+ cover/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ .pybuilder/
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ # For a library or package, you might want to ignore these files since the code is
91
+ # intended to run in multiple environments; otherwise, check them in:
92
+ # .python-version
93
+
94
+ # pipenv
95
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
97
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
98
+ # install all needed dependencies.
99
+ #Pipfile.lock
100
+
101
+ # poetry
102
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
104
+ # commonly ignored for libraries.
105
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106
+ #poetry.lock
107
+
108
+ # pdm
109
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110
+ #pdm.lock
111
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112
+ # in version control.
113
+ # https://pdm.fming.dev/#use-with-ide
114
+ .pdm.toml
115
+
116
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117
+ __pypackages__/
118
+
119
+ # Celery stuff
120
+ celerybeat-schedule
121
+ celerybeat.pid
122
+
123
+ # SageMath parsed files
124
+ *.sage.py
125
+
126
+ # Environments
127
+ .env
128
+ .venv
129
+ env/
130
+ venv/
131
+ ENV/
132
+ env.bak/
133
+ venv.bak/
134
+
135
+ # Spyder project settings
136
+ .spyderproject
137
+ .spyproject
138
+
139
+ # Rope project settings
140
+ .ropeproject
141
+
142
+ # mkdocs documentation
143
+ /site
144
+
145
+ # mypy
146
+ .mypy_cache/
147
+ .dmypy.json
148
+ dmypy.json
149
+
150
+ # Pyre type checker
151
+ .pyre/
152
+
153
+ # pytype static type analyzer
154
+ .pytype/
155
+
156
+ # Cython debug symbols
157
+ cython_debug/
158
+
159
+ # PyCharm
160
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
163
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164
+ #.idea/
165
+
166
+ ### Python Patch ###
167
+ # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
168
+ poetry.toml
169
+
170
+ # ruff
171
+ .ruff_cache/
172
+
173
+ # LSP config files
174
+ pyrightconfig.json
175
+
176
+ # End of https://www.toptal.com/developers/gitignore/api/python
177
+
178
+ .streamlit
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pickle
3
+ import streamlit as st
4
+ from transformers import Conversation, pipeline
5
+ from upload import get_file, upload_file
6
+ from utils import clear_uploader, undo, restart
7
+
8
+
9
# Session-state keys that are pickled when a chat is shared.
share_keys = ["messages", "model_name"]

# Conversational models offered in the sidebar selector.
MODELS = [
    "facebook/blenderbot-400M-distill",
    "facebook/blenderbot-90M",
]

st.set_page_config(page_title="LLM", page_icon="📚")

# First run of a session: start with the distilled 400M model.
if "model_name" not in st.session_state:
    st.session_state["model_name"] = "facebook/blenderbot-400M-distill"
19
+
20
+
21
@st.cache_resource
def get_pipeline(model_name):
    """Return the conversational pipeline for ``model_name``.

    Streamlit re-executes the whole script on every interaction, so without
    caching the model would be reloaded each time. ``st.cache_resource``
    keeps one pipeline per distinct ``model_name`` and reuses it across
    reruns.

    Args:
        model_name: Model identifier passed to ``transformers.pipeline``.

    Returns:
        The pipeline object created with ``task="conversational"``.
    """
    # Use the first CUDA device when one is available, otherwise the CPU (-1).
    device = 0 if torch.cuda.is_available() else -1
    return pipeline(model=model_name, task="conversational", device=device)
25
+
26
if "messages" not in st.session_state:
    st.session_state.messages = []

# Restore a shared chat when the URL carries an ``id`` and nothing has been
# typed yet in this session.
if not st.session_state.messages and "id" in st.query_params:
    with st.spinner("Loading chat..."):
        chat_id = st.query_params["id"]
        data = get_file(chat_id, 'llm-007')
        obj = None
        if not data:
            # get_file returns False when the object could not be fetched.
            st.error("Could not load the shared chat.")
        else:
            # SECURITY: the id comes from the URL, so this unpickles data
            # selected by the visitor. pickle can execute arbitrary code while
            # loading; consider switching the share format to JSON.
            try:
                obj = pickle.loads(data)
            except Exception as e:  # unpickling can raise many error types
                print(e)
                st.error("The shared chat is corrupted and could not be loaded.")
        if isinstance(obj, dict):
            # Only restore the keys that share() writes, never arbitrary ones.
            for k in share_keys:
                if k not in obj:
                    continue
                if k == "model_name" and obj[k] not in MODELS:
                    # Ignore model names this app does not offer.
                    continue
                st.session_state[k] = obj[k]

# Build the pipeline after a shared chat has been restored so that a shared
# model name is honoured on this very run.
chatbot = get_pipeline(st.session_state.model_name)
38
+
39
+
40
def share():
    """Upload the shareable session state and display a link to it.

    The keys listed in ``share_keys`` that are present in
    ``st.session_state`` are pickled and uploaded with ``upload_file``.
    On success a relative link and the absolute share URL are rendered;
    when the upload fails (``upload_file`` returns ``None``) an error is
    shown instead of a broken link.
    """
    payload = {k: st.session_state[k] for k in share_keys if k in st.session_state}
    data = pickle.dumps(payload)
    chat_id = upload_file(data, 'llm-007')
    if chat_id is None:
        # Without this guard the UI would advertise "?id=None" as a success.
        st.error("Could not share the chat. Please try again.")
        return
    url = f"https://umbc-nlp-llm.hf.space/?id={chat_id}"
    st.markdown(f"[share](/?id={chat_id})")
    st.success(f"Share URL: {url}")
50
+
51
with st.sidebar:
    st.title(":blue[LLM Only]")

    st.subheader("Model")
    # The stored name may not be one of MODELS (e.g. restored from a shared
    # chat); fall back to the first entry instead of raising ValueError.
    current_model = st.session_state.model_name
    model_index = MODELS.index(current_model) if current_model in MODELS else 0
    model_name = st.selectbox("Model", MODELS, index=model_index)
    if model_name != st.session_state.model_name:
        # Persist the choice and rerun so the pipeline, which is created from
        # st.session_state.model_name earlier in the script, picks it up.
        st.session_state.model_name = model_name
        st.rerun()

    if st.button("Share", use_container_width=True):
        share()

    cols = st.columns(2)
    with cols[0]:
        if st.button("Restart", type="primary", use_container_width=True):
            restart()

    with cols[1]:
        if st.button("Undo", use_container_width=True):
            undo()

    # When checked, a new user message is stored without asking the model
    # for a reply.
    append = st.checkbox("Append to previous message", value=False)
70
+
71
+
72
# Replay the stored conversation, one chat bubble per message.
for entry in st.session_state.messages:
    bubble = st.chat_message(entry["role"])
    bubble.markdown(entry["content"])
75
+
76
+
77
def push_message(role, content):
    """Append a ``{"role", "content"}`` entry to the session's messages.

    Returns the dict that was appended.
    """
    entry = dict(role=role, content=content)
    st.session_state.messages.append(entry)
    return entry
81
+
82
# Handle a newly submitted user message, if any.
prompt = st.chat_input("Type a message", key="chat_input")
if prompt:
    push_message("user", prompt)
    st.chat_message("user").markdown(prompt)

    # In "append" mode the message is only stored; no reply is generated.
    if not append:
        with st.chat_message("assistant"):
            # Feed the whole history to the model as one conversation.
            conversation = Conversation()
            for past_message in st.session_state.messages:
                conversation.add_message(past_message)
            print(conversation)
            with st.spinner("Generating response..."):
                result = chatbot(conversation)
                # The last entry of the returned conversation is the reply.
                response = result[-1]["content"]
                st.write(response)

        push_message("assistant", response)

    # Rerun so the page is redrawn from the updated message list.
    clear_uploader()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ streamlit
2
+ boto3
upload.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import uuid
3
+ import boto3
4
+
5
# Module-wide S3 client; the credentials are read from Streamlit secrets when
# this module is imported.
s3 = boto3.client(
    service_name='s3',
    aws_access_key_id=st.secrets["AWS_ACCESS_KEY"],
    aws_secret_access_key=st.secrets["AWS_SECRET_KEY"],
)
10
+
11
def upload_file(data, bucket=st.secrets["S3_BUCKET"]):
    """Store ``data`` in S3 under a freshly generated name.

    Args:
        data: Object body passed to ``put_object``.
        bucket: Target bucket; defaults to the ``S3_BUCKET`` secret, read
            when this module is imported.

    Returns:
        The generated name (a UUID4 hex string, without the ``.pkl``
        suffix used for the object key), or ``None`` if the upload failed.
    """
    file_name = uuid.uuid4().hex
    try:
        s3.put_object(Body=data, Bucket=bucket, Key=f"{file_name}.pkl")
    except Exception as e:
        # Log the failure (as get_file does) instead of discarding it, so a
        # failed share can be diagnosed from the server output.
        print(e)
        return None
    return file_name
19
+
20
+
21
def get_file(file_name, bucket=st.secrets["S3_BUCKET"]):
    """Fetch the object ``<file_name>.pkl`` from S3.

    Returns the object's bytes, or ``False`` (after printing the error)
    when the download raises.
    """
    try:
        s3_object = s3.get_object(Bucket=bucket, Key=f"{file_name}.pkl")
        payload = s3_object['Body'].read()
    except Exception as err:
        print(err)
        return False
    return payload
utils.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import streamlit as st
3
+
4
+
5
def clear_uploader():
    """Ask Streamlit to rerun the script immediately."""
    st.rerun()
7
+
8
def undo():
    """Remove the most recent message and rerun the app.

    Does nothing when there are no messages. The URL query parameters are
    cleared as part of the undo. If the removed message came from the
    assistant and a non-empty ``cost`` list exists in the session state,
    its last entry is dropped as well.
    """
    messages = st.session_state.messages
    if not messages:
        return

    st.query_params.clear()
    removed = messages.pop()
    # restart() resets ``cost`` to an empty list, so the key can exist while
    # the list is empty; popping unconditionally would raise IndexError.
    if (
        removed["role"] == "assistant"
        and "cost" in st.session_state
        and st.session_state.cost
    ):
        st.session_state.cost.pop()
    time.sleep(0.1)
    st.rerun()
16
+
17
def restart():
    """Reset the conversation: clear the query params, messages and cost."""
    st.query_params.clear()
    for key in ("messages", "cost"):
        st.session_state[key] = []
    time.sleep(0.2)
    clear_uploader()