|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import streamlit as st |
|
from streamlit_chat import message |
|
from langchain.chains import ConversationChain |
|
from langchain.memory import ConversationBufferMemory |
|
from langchain.llms.huggingface_hub import HuggingFaceHub |
|
from typing import Dict |
|
import json |
|
from io import StringIO |
|
from random import randint |
|
from transformers import AutoTokenizer |
|
|
|
# Page configuration is issued before any other Streamlit call in the script.
st.set_page_config(
    page_title="Document Analysis",
    page_icon=":robot:",
)
st.header("Chat with your document π (Model: Falcon-7B-Instruct)")

# Hugging Face repo id of the chat model; also selects the matching tokenizer.
model_name = "tiiuae/falcon-7b-instruct"
# Tokenizer for that model, used to count the tokens held in conversation memory.
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
class CustomHuggingFaceEndpoint(HuggingFaceHub):
    """HuggingFaceHub LLM with custom request/response (de)serialisation hooks.

    ``transform_input`` builds the JSON request body for a prompt and
    ``transform_output`` extracts the model's answer from the JSON response.

    NOTE(review): whether the ``HuggingFaceHub`` base class actually invokes
    these two hooks is not visible in this file -- confirm against the
    installed langchain version.
    """

    # Length (in characters) of the most recent prompt. Declared as a class
    # field so that it can be assigned on the instance; transform_output uses
    # it to strip the echoed prompt from the generated text.
    len_prompt: int = 0

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        """Serialise ``prompt`` into the UTF-8 encoded JSON request body.

        Args:
            prompt: Full prompt text sent to the model.
            model_kwargs: Accepted for interface compatibility; the generation
                parameters below are fixed and this argument is not read.

        Returns:
            The JSON payload encoded as UTF-8 bytes.
        """
        # Remember the prompt length on the instance. The original kept it in
        # a local variable, so transform_output's read of self.len_prompt
        # could never see it.
        self.len_prompt = len(prompt)
        input_str = json.dumps({
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": 100,
                "stop": ["Human:"],
                "do_sample": False,
                "repetition_penalty": 1.1,
            },
        })
        return input_str.encode('utf-8')

    def transform_output(self, output: bytes) -> str:
        """Extract the answer text from the raw JSON response.

        Args:
            output: UTF-8 encoded JSON; expected to be a list whose first item
                has a ``generated_text`` entry.

        Returns:
            The generated text with the echoed prompt removed, cut at the last
            "Human" marker when one is present, and stripped of surrounding
            whitespace.
        """
        res = json.loads(output.decode('utf-8'))
        ans = res[0]['generated_text'][self.len_prompt:]
        # Drop a trailing "Human..." turn if the model started one. rfind()
        # returns -1 when the marker is absent; slicing with -1 would silently
        # chop the last character of a perfectly good answer.
        cut = ans.rfind("Human")
        if cut != -1:
            ans = ans[:cut]
        return ans.strip()
|
|
|
|
|
def load_chain():
    """Build a ConversationChain backed by the custom endpoint LLM.

    The chain gets a new, empty ConversationBufferMemory each time this
    function is called.
    """
    return ConversationChain(
        llm=CustomHuggingFaceEndpoint(repo_id=model_name),
        memory=ConversationBufferMemory(),
    )
|
|
|
# Streamlit re-executes this script from the top on every interaction. Keep the
# chain (and the conversation memory it owns) in session state; calling
# load_chain() unconditionally would hand every rerun a fresh, empty memory and
# the model would never see earlier turns.
if 'chain' not in st.session_state:
    st.session_state['chain'] = load_chain()
chatchain = st.session_state['chain']

# Per-session chat transcript: 'past' holds user turns, 'generated' holds replies.
if 'generated' not in st.session_state:
    st.session_state['generated'] = []
if 'past' not in st.session_state:
    st.session_state['past'] = []
    chatchain.memory.clear()
# The uploader's widget key is randomised so that changing it resets the widget.
if 'widget_key' not in st.session_state:
    st.session_state['widget_key'] = str(randint(1000, 100000000))
# Identity of the upload that has already been forwarded to the model.
if 'sent_file' not in st.session_state:
    st.session_state['sent_file'] = None

st.sidebar.title("Sidebar")
clear_button = st.sidebar.button("Clear Conversation", key="clear")

if clear_button:
    # Reset transcript, uploader (via a new widget key) and model memory.
    st.session_state['generated'] = []
    st.session_state['past'] = []
    st.session_state['widget_key'] = str(randint(1000, 100000000))
    st.session_state['sent_file'] = None
    chatchain.memory.clear()

uploaded_file = st.sidebar.file_uploader(
    "Upload a txt file", type=["txt"], key=st.session_state['widget_key']
)

# The transcript is rendered above the input form, so its container comes first.
response_container = st.container()
container = st.container()

with container:
    with st.form(key='my_form', clear_on_submit=True):
        user_input = st.text_area("You:", key='input', height=100)
        submit_button = st.form_submit_button(label='Send')

    if submit_button and user_input:
        output = chatchain(user_input)["response"]
        st.session_state['past'].append(user_input)
        st.session_state['generated'].append(output)
    elif uploaded_file is not None:
        # The uploader keeps returning the same file on later reruns (for
        # example an empty "Send"); forward each upload to the model only once.
        file_marker = (
            st.session_state['widget_key'],
            uploaded_file.name,
            uploaded_file.size,
        )
        if st.session_state['sent_file'] != file_marker:
            file_text = uploaded_file.getvalue().decode("utf-8").strip()
            content = (
                "=== BEGIN FILE ===\n"
                + file_text
                + "\n=== END FILE ===\nPlease confirm that you have read that file by saying 'Yes, I have read the file'"
            )
            output = chatchain(content)["response"]
            st.session_state['past'].append("I have uploaded a file. Please confirm that you have read that file.")
            st.session_state['generated'].append(output)
            # Mark as sent only after the call succeeded, so a failure is retried.
            st.session_state['sent_file'] = file_marker

    # Show how many tokens the accumulated conversation history occupies.
    history = chatchain.memory.load_memory_variables({})["history"]
    tokens = tokenizer.tokenize(history)
    st.write(f"Number of tokens in memory: {len(tokens)}")

# Render the transcript as alternating user / model chat bubbles.
if st.session_state['generated']:
    with response_container:
        turns = zip(st.session_state["past"], st.session_state["generated"])
        for i, (user_turn, model_turn) in enumerate(turns):
            message(user_turn, is_user=True, key=str(i) + '_user')
            message(model_turn, key=str(i))