llamazookeeper committed on
Commit
3d803ed
•
1 Parent(s): d36b7c3
Files changed (2)
  1. app.py +68 -115
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,116 +1,69 @@
- # Import streamlit for app dev
  import streamlit as st
-
- # Import transformer classes for generation
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
- # Import torch for datatype attributes
- import torch
- # Import the prompt wrapper...but for llama index
- from llama_index.prompts.prompts import SimpleInputPrompt
- # Import the llama index HF Wrapper
- from llama_index.llms import HuggingFaceLLM
- # Bring in embeddings wrapper
- from llama_index.embeddings import LangchainEmbedding
- # Bring in HF embeddings - need these to represent document chunks
- from langchain.embeddings.huggingface import HuggingFaceEmbeddings
- # Bring in stuff to change service context
- from llama_index import set_global_service_context
- from llama_index import ServiceContext
- # Import deps to load documents
- from llama_index import VectorStoreIndex, download_loader
- from pathlib import Path
-
-
-
- # Define variable to hold llama2 weights naming
- #name = "meta-llama/Llama-2-70b-chat-hf"
- name = "mistralai/Mistral-7B-v0.1"
- # Set auth token variable from hugging face
- #auth_token = "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" #llamazookeeper
- auth_token = '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'
-
-
-
- @st.cache_resource
- def get_tokenizer_model():
-     # Create tokenizer
-     tokenizer = AutoTokenizer.from_pretrained(name, cache_dir='./model/', use_auth_token=auth_token)
-
-     # Create model
-     model = AutoModelForCausalLM.from_pretrained(name, cache_dir='./model/',
-                                                  use_auth_token=auth_token, torch_dtype=torch.float16,
-                                                  rope_scaling={"type": "dynamic", "factor": 2}, load_in_8bit=True)
-
-     return tokenizer, model
- tokenizer, model = get_tokenizer_model()
-
- # Create a system prompt
- system_prompt = """<s>[INST] <<SYS>>
- You are a helpful, respectful and honest assistant. Always answer as
- helpfully as possible, while being safe. Your answers should not include
- any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
- Please ensure that your responses are socially unbiased and positive in nature.
-
- If a question does not make any sense, or is not factually coherent, explain
- why instead of answering something not correct. If you don't know the answer
- to a question, please don't share false information.
-
- Your goal is to provide answers relating to the financial performance of
- the company.<</SYS>>
- """
- # Throw together the query wrapper
- query_wrapper_prompt = SimpleInputPrompt("{query_str} [/INST]")
-
- # Create a HF LLM using the llama index wrapper
- llm = HuggingFaceLLM(context_window=4096,
-                      max_new_tokens=256,
-                      system_prompt=system_prompt,
-                      query_wrapper_prompt=query_wrapper_prompt,
-                      model=model,
-                      tokenizer=tokenizer)
-
- # Create and dl embeddings instance
- embeddings = LangchainEmbedding(
-     HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
- )
-
- # Create new service context instance
- service_context = ServiceContext.from_defaults(
-     chunk_size=1024,
-     llm=llm,
-     embed_model=embeddings
- )
- # And set the service context
- set_global_service_context(service_context)
-
- # Download PDF Loader
- PyMuPDFReader = download_loader("PyMuPDFReader")
- # Create PDF Loader
- loader = PyMuPDFReader()
- # Load documents
- documents = loader.load(file_path=Path('/content/*.pdf'), metadata=True)
-
- # Create an index - we'll be able to query this in a sec
- index = VectorStoreIndex.from_documents(documents)
- # Setup index query engine using LLM
- query_engine = index.as_query_engine()
-
-
- # Create centered main title
- #st.title('🦙 Llama Banker')
- # Create a text input box for the user
- prompt = st.text_input('Input your prompt here')
-
- # If the user hits enter
- if prompt:
-     response = query_engine.query(prompt)
-     # ...and write it out to the screen
-     st.write(response)
-
-     # Display raw response object
-     with st.expander('Response Object'):
-         st.write(response)
-     # Display source text
-     with st.expander('Source Text'):
-         st.write(response.get_formatted_sources())
-
 
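Review note on the removed file: it hardcoded two Hugging Face access tokens in source. Any token committed to a repo's history should be treated as compromised and revoked; the usual pattern is to read it from the environment instead. A minimal sketch, assuming the token is stored under a hypothetical secret name `HF_TOKEN`:

import os

# Read the Hugging Face token from the environment instead of hardcoding it.
# On Spaces, secrets configured in the repo settings are exposed as env vars.
auth_token = os.environ.get("HF_TOKEN")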
+ # imports
  import streamlit as st
+ from newspaper import Article
+ from transformers import pipeline
+
+ # set config
+ st.set_page_config(layout="wide", page_title="SummarizeLink")
+
+ # load the summarization model (cached for faster reloads)
+ @st.cache(allow_output_mutation=True)
+ def load_summarize_model():
+     # model = pipeline("summarization", model='sshleifer/distilbart-cnn-12-6')
+     model = pipeline("summarization")
+     return model
+
+ # load the model
+ summ = load_summarize_model()
+
+ # define the download helper
+ def download_and_parse_article(url):
+     """Download and parse an article from a URL.
+
+     Parameters
+     ----------
+     url : str
+         The URL of the article to download and parse.
+
+     Returns
+     -------
+     text : str
+         The full text of the downloaded and parsed article.
+     """
+     # build the article object
+     article = Article(url)
+     # download and parse the article
+     article.download()
+     article.parse()
+     # return the article text
+     return article.text
+
+ # APP
+ # set title and subtitle
+ st.title("SummarizeLink")
+ st.markdown("Paste any article link below and click on the 'Summarize' button.")
+ st.markdown("*Note:* We truncate the text in case the article is lengthy! 🖖")
+ # create the input text box and settings panel
+ link = st.text_area('Paste your link here...', "https://towardsdatascience.com/a-guide-to-the-knowledge-graphs-bfb5c40272f1", height=50)
+ button = st.button("Summarize")
+ min_length = st.sidebar.slider('Min summary length', min_value=10, max_value=100, value=50, step=10)
+ max_length = st.sidebar.slider('Max summary length', min_value=30, max_value=700, value=100, step=10)
+ num_beams = st.sidebar.slider('Beam length', min_value=1, max_value=10, value=5, step=1)
+
+ # if the button is clicked, fetch and summarize
+ if button and link:
+     with st.spinner("Parsing article and summarizing..."):
+         # get the text
+         text = download_and_parse_article(link)
+         # summarize the text
+         summary = summ(text,
+                        truncation=True,
+                        max_length=max_length,
+                        min_length=min_length,
+                        num_beams=num_beams,
+                        do_sample=True,
+                        early_stopping=True,
+                        repetition_penalty=1.5,
+                        length_penalty=1.5)[0]
+         # display the summary
+         st.markdown("**Summary:**")
+         st.write(summary['summary_text'])
 
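Review note: `@st.cache(allow_output_mutation=True)` still works on older Streamlit releases, but Streamlit 1.18+ deprecates it in favor of `st.cache_resource`, the decorator the removed file already used. A minimal sketch of the same loader on a newer Streamlit; only the decorator changes:

import streamlit as st
from transformers import pipeline

@st.cache_resource  # replaces st.cache(allow_output_mutation=True) on Streamlit >= 1.18
def load_summarize_model():
    # A bare pipeline("summarization") resolves to the distilbart checkpoint
    # that the commented-out line in the diff pins explicitly.
    return pipeline("summarization")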
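Review note: with `truncation=True`, the pipeline silently drops everything past the model's input limit (1024 tokens for the BART-family summarizers), so long articles are summarized from their opening passage only, exactly as the app's note warns. A sketch of one workaround, splitting on characters; `chunk_size` here is a hypothetical rough proxy for the token limit, and token-aware splitting would be more precise:

def summarize_long(summ, text, chunk_size=3500):
    # Split the article into roughly chunk_size-character pieces,
    # summarize each piece, then join the partial summaries.
    chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
    parts = [summ(c, truncation=True, max_length=100, min_length=30)[0]['summary_text']
             for c in chunks]
    return " ".join(parts)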
requirements.txt CHANGED
@@ -5,3 +5,4 @@ transformers
  accelerate
  bitsandbytes
  requests
+ newspaper3k
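Review note: on PyPI the maintained Python 3 release of this library is published as `newspaper3k`; the name `newspaper` points at the old Python 2 package, hence the corrected requirement above. A quick smoke test after installing, reusing the app's default URL:

from newspaper import Article  # the import name stays `newspaper` even though the package is newspaper3k

article = Article("https://towardsdatascience.com/a-guide-to-the-knowledge-graphs-bfb5c40272f1")
article.download()
article.parse()
print(article.title, len(article.text))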