krystv committed on
Commit 8d28b9a
1 Parent(s): 1da781e

Upload app.py

Files changed (1)
  1. app.py +215 -147
app.py CHANGED
@@ -1,152 +1,220 @@
- import subprocess
- import sys
-
- subprocess.check_call([sys.executable,"-m","pip","install",'causal-conv1d'])
- subprocess.check_call([sys.executable, "-m", "pip", "install", 'miditok','mamba-ssm','gradio'])
- subprocess.check_call(["apt-get", "install", "timidity", "-y"])
-
- # !pip install pretty_midi midi2audio
- # !pip install miditok
- # !apt-get install fluidsynth
- # !apt install timidity -y
- # !pip install causal-conv1d>=1.1.0
- # !pip install mamba-ssm
- # !pip install gradio
-
-
-
- # !export LC_ALL="en_US.UTF-8"
- # !export LD_LIBRARY_PATH="/usr/lib64-nvidia"
- # !export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
-
- # subprocess.check_call(['export', 'LC_ALL="en_US.UTF-8"'])
- # subprocess.check_call(['export', 'LD_LIBRARY_PATH="/usr/lib64-nvidia"'])
- # subprocess.check_call(['export', 'LIBRARY_PATH="/usr/local/cuda/lib64/stubs"'])
- import os
-
- os.environ['LC_ALL'] = "en_US.UTF-8"
- os.environ['LD_LIBRARY_PATH'] = "/usr/lib64-nvidia"
- os.environ['LIBRARY_PATH'] = "/usr/local/cuda/lib64/stubs"
-
-
-
  import gradio as gr
- import torch
- from mamba_ssm import Mamba
- from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
- from mamba_ssm.models.config_mamba import MambaConfig
- import numpy as np
-
- device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
- if torch.cuda.is_available():
-     subprocess.check_call(['ldconfig', '/usr/lib64-nvidia'])
-     # !ldconfig /usr/lib64-nvidia
-
- # !wget "https://huggingface.co/krystv/MIDI_Mamba-159M/resolve/main/MIDI_Mamba-159M_1536VS.pt"
- # !wget "https://huggingface.co/krystv/MIDI_Mamba-159M/resolve/main/tokenizer_1536mix_BPE.json"
- if os.path.isfile("MIDI_Mamba-159M_1536VS.pt") == False:
-     subprocess.check_call(['wget', 'https://huggingface.co/krystv/MIDI_Mamba-159M/resolve/main/MIDI_Mamba-159M_1536VS.pt'])
-
- if os.path.isfile("tokenizer_1536mix_BPE.json") == False:
-     subprocess.check_call(['wget', 'https://huggingface.co/krystv/MIDI_Mamba-159M/resolve/main/tokenizer_1536mix_BPE.json'])
-
-
-
- mc = MambaConfig()
- mc.d_model = 768
- mc.n_layer = 42
- mc.vocab_size = 1536
-
- from miditok import MIDILike,REMI,TokenizerConfig
- from pathlib import Path
- import torch
-
- tokenizer = REMI(params='tokenizer_1536mix_BPE.json')
-
-
-
- mf = MambaLMHeadModel(config=mc,device=device)
- mf.load_state_dict(torch.load("/content/MIDI_Mamba-159M_1536VS.pt",map_location=device))
-
-
-
- twitter_follow_link = "https://twitter.com/iamhemantindia"
- instagram_follow_link = "https://instagram.com/iamhemantindia"
-
- custom_html = f"""
- <div style='text-align: center;'>
-     <a href="{twitter_follow_link}" target="_blank" style="margin-right: 5px;">
-         <img src="https://img.icons8.com/fluent/24/000000/twitter.png" alt="Follow on Twitter"/>
-     </a>
-     <a href="{instagram_follow_link}" target="_blank">
-         <img src="https://img.icons8.com/fluent/24/000000/instagram-new.png" alt="Follow on Instagram"/>
-     </a>
- </div>
  """


- @spaces.GPU(duration=120)
- def generate(number,top_k_selector,top_p_selector, temperature_selector):
-     input_ids = torch.tensor([[1,]]).to(device)
-     out = mf.generate(
-         input_ids=input_ids,
-         max_length=int(number),
-         temperature=temperature_selector,
-         top_p=top_p_selector,
-         top_k=top_k_selector,
-
-         eos_token_id=2,)
-     m = tokenizer.decode(np.array(out[0].to('cpu')))
-     np.array(out.to('cpu')).shape
-     m.dump_midi('output.mid')
-     # !timidity output.mid -Ow -o - | ffmpeg -y -f wav -i - output.mp3
-     timidity_cmd = ['timidity', 'output.mid', '-Ow', '-o', 'output.wav']
-     subprocess.check_call(timidity_cmd)
-
-     # Then convert the WAV to MP3 using ffmpeg
-     ffmpeg_cmd = ['ffmpeg', '-y', '-f', 'wav', '-i', 'output.wav', 'output.mp3']
-     subprocess.check_call(ffmpeg_cmd)
-
-     return "output.mp3"
-
-
- # text_box = gr.Textbox(label="Enter Text")
-
-
- def generate_and_save(number,top_k_selector,top_p_selector, temperature_selector,generate_button,custom_html_wid):
-     output_audio = generate(number,top_k_selector,top_p_selector, temperature_selector)
-     return gr.Audio(output_audio,autoplay=True),gr.File(label="Download MIDI",value="output.mid"),generate_button
-
-
-
-
-
-
- # iface = gr.Interface(fn=generate_and_save,
- #                      inputs=[number_selector,top_k_selector,top_p_selector, temperature_selector,generate_button,custom_html_wid],
- #                      outputs=[output_box,download_midi_button],
- #                      title="MIDI Mamba-159M",submit_btn=False,
- #                      clear_btn=False,
- #                      description="MIDI Mamba is a Mamba based model trained on MIDI data collected from open internet to train music model.",
- #                      allow_flagging=False,)
-
- with gr.Blocks() as b1:
-     gr.Markdown("<h1 style='text-align: center;'>MIDI Mamba-159M <h1/> ")
-     gr.Markdown("<h3 style='text-align: center;'>MIDI Mamba is a Mamba based model trained on MIDI data collected from open internet to train music model. <br> by Hemant Kumar<h3/>")
      with gr.Row():
-         with gr.Column():
-             number_selector = gr.Number(label="Select Length of output",value=512)
-             top_p_selector = gr.Slider(label="Select Top P", minimum=0, maximum=1.0, step=0.05, value=0.9)
-             temperature_selector = gr.Slider(label="Select Temperature", minimum=0, maximum=1.0, step=0.1, value=0.9)
-             top_k_selector = gr.Slider(label="Select Top K", minimum=1, maximum=1536, step=1, value=30)
-             generate_button = gr.Button(value="Generate",variant="primary")
-             custom_html_wid = gr.HTML(custom_html)
-         with gr.Column():
-             output_box = gr.Audio("output.mp3",autoplay=True,)
-             download_midi_button = gr.File(label="Download MIDI")
-     generate_button.click(generate_and_save,inputs=[number_selector,top_k_selector,top_p_selector, temperature_selector,generate_button,custom_html_wid],outputs=[output_box,download_midi_button,generate_button])
-
-
-
-
- b1.launch(share=True)
  import gradio as gr
+ import time
+ import os
+ import shutil
+ import streamlit as st
+ openai_api = st.secrets["OPENAI_API_KEY"]
+
+ doc_store_path = os.path.join(os.path.dirname(__file__), "doc_dir")
+ if not os.path.isdir(doc_store_path):
+     os.makedirs(doc_store_path)
+
+ from llama_index.core import SimpleDirectoryReader, VectorStoreIndex,Settings
+ from llama_index.core.node_parser import SentenceSplitter,SemanticSplitterNodeParser
+ from llama_index.llms.openai import OpenAI
+ from llama_index.llms.openai import OpenAI as OpenAIsum
+ from llama_index.embeddings.openai import OpenAIEmbedding
+ from llama_index.core.storage import StorageContext
+ from llama_index.vector_stores.chroma import ChromaVectorStore
+ from llama_index.core.storage.chat_store import SimpleChatStore
+ from llama_index.core.memory import ChatMemoryBuffer,ChatSummaryMemoryBuffer
+
+ import json
+ import chromadb
+ import tiktoken
+
+
+ chat_store = SimpleChatStore()
+ # chat_memory = ChatMemoryBuffer.from_defaults(
+ #     token_limit=3000,
+ #     chat_store=chat_store,
+ #     chat_store_key="user1",
+ # )
+
+
+ sum_llm = OpenAIsum(api_key=openai_api, model="gpt-3.5-turbo", max_tokens=256)
+ chat_summary_memory = ChatSummaryMemoryBuffer.from_defaults(
+     token_limit=256,
+     chat_store=chat_store,
+     chat_store_key="user1",
+     llm = sum_llm,
+     tokenizer_fn = tiktoken.encoding_for_model("gpt-3.5-turbo").encode
+ )
+
+
+ chat_store = SimpleChatStore.from_persist_path(
+     persist_path="chat_store.json"
+ )
+
+
+
+ # documents = SimpleDirectoryReader("./data").load_data()
+ db = chromadb.PersistentClient(path="./chroma_db")
+
+ chroma_collection = db.get_or_create_collection("quickstart")
+
+ vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
+ storage_context = StorageContext.from_defaults(vector_store=vector_store)
+
+ Settings.llm = OpenAI(model="gpt-3.5-turbo",api_key=openai_api,)
+ Settings.embed_model = OpenAIEmbedding(model="text-embedding-ada-002")
+
+ vector_index = VectorStoreIndex.from_vector_store(vector_store, storage_context=storage_context,)
+ query_engine = vector_index.as_chat_engine(chat_memory=chat_summary_memory,storage_context=storage_context,use_async=True,similarity_top_k=2)
+
+ current_refs = ""
+
+ def metadata_from_doc(vec_index: VectorStoreIndex) -> dict:
+     qe = vec_index.as_chat_engine()
+     # f_prompt = """
+     # Given the text excerpts, analyze and provide the document's title and creation date in a structured JSON format. Here are a few examples:
+
+     # In this format:
+
+     # {
+     # "creation_date": "YYYY-MM-DD",
+     # "title": "Title of the Document"
+     # }
+
+     # Text: 'An analysis of historical events. Written by Alex Johnson on 5 March 2019.'
+     # Output: { "title": "An analysis of historical events", "creation_date": "2019-03-05" }
+
+     # Text: 'Exploring the depths of the ocean. This comprehensive guide was authored by Dr. Emily White, published on 10-July 2021.'
+     # Output: { "title": "Exploring the depths of the ocean", "creation_date": "2021-07-10" }
+
+     # Text: 'The history of the Roman Empire.'
+     # Output: { "title": "The history of the Roman Empire", "creation_date": "Unknown" }
+
+
+     # Now, analyze the context from the provided document and generate json object.
+     # """
+     f_prompt ="""give me a only the data when this document was written and title of this document? in json format parameter (created_date,title),
+ example context: 'An analysis of historical events. Written by Alex Johnson on 5 March 2019.'
+ example output: { "title": "An analysis of historical events", "creation_date": "2019-03-05" }
+ now analyse the context make sure to return output only in json format object only.
+ """
+     res = qe.query(f_prompt)
+     parsed = json.loads(res.response)
+     return parsed
+
+ def filter_unsaved(file_paths:list):
+     for i in file_paths:
+         if os.path.isfile(os.path.join(doc_store_path,os.path.basename(i))):
+             file_paths.remove(i)
+             print("File already exist : {}".format(i))
+         else:
+             shutil.copy2(i,doc_store_path)
+     return file_paths
+
+ def add_doc(file_paths:list):
+     print(file_paths)
+     file_paths = filter_unsaved(file_paths)
+     print(file_paths)
+     if len(file_paths) == 0:
+         return
+     docs = SimpleDirectoryReader(input_files=file_paths).load_data()
+     splitter = SemanticSplitterNodeParser(buffer_size=1, breakpoint_percentile_threshold=95, embed_model=Settings.embed_model,chunk_size=256)
+     nodes = splitter.get_nodes_from_documents(docs)
+     vector_index2 = VectorStoreIndex(nodes)
+     for i in range (5):
+         try:
+             meta = metadata_from_doc(vector_index2)
+             break
+         except:
+             meta = {
+                 "title": "Unknown",
+                 "creation_date": "Unknown"
+             }
+             continue
+
+     print(meta)
+     for i in range(len(nodes)):
+         nodes[i].metadata.update(meta)
+     vector_index.insert_nodes(nodes)
+
+
+
+
+
+ CSS ="""
+ .contain { display: flex; flex-direction: column; }
+ .gradio-container { height: 100vh !important; }
+ #component-0 { height: 100%; }
+ #chatbot { flex-grow: 1; overflow: auto;}
  """


+ def new_chat(chatbot:gr.Chatbot,textbox):
+     query_engine.reset()
+     return gr.update(value=""),[],"",gr.File(visible=False),gr.File(visible=False)
+
+
+ def chat(history, input):
+     response = query_engine.chat(str(input))
+     global current_refs
+     files = []
+     current_refs = ""
+     for node in response.source_nodes:
+         try:
+             current_refs += f"{str(node.metadata['title'])},"
+         except:
+             current_refs += ""
+         try:
+             current_refs += f"Pg - {str(node.metadata['page_label'])},"
+         except:
+             current_refs += "Pg - ,"
+         try:
+             current_refs += f"File - {str(node.metadata['file_name'])} \n,"
+         except:
+             current_refs += "File - ,\n"
+
+         try:
+             files.append({'path':node.metadata['file_path'],'show':True,})
+         except:
+             files.append({'path':None,'show':False,})
+
+     if len(files) < 2:
+         for _ in range(2-len(files)):
+             files.append({'path':None,'show':False,})
+
+     return gr.update(value=""),history + [(input, response.response)],current_refs,gr.update(visible=files[0]['show'],value=files[0]['path']),gr.update(visible=files[1]['show'],value=files[1]['path'])
+
+ def file_upload(file,chatbot):
+     print(file)
+     add_doc(file)
+     return gr.update(value="ChatDoc"),chatbot
+
+ with gr.Blocks(fill_height=True, css=CSS) as demo:
      with gr.Row():
+         with gr.Column(scale=1):
+             title = gr.Label(value="chatdoc", label="ChatDoc")
+             files = gr.UploadButton(
+                 "📁 Upload PDF or doc files", file_types=[
+                     '.pdf',
+                     '.doc'
+                 ],
+                 file_count="multiple")
+             references = gr.Textbox(label="References",interactive=False)
+             file_down1 = gr.File(visible=False)
+             file_down2 = gr.File(visible=False)
+
+
+
+         with gr.Column(scale=9,):
+             chatbot = gr.Chatbot(
+                 elem_id="chatbot",
+                 bubble_full_width=False,
+                 label="ChatDoc",
+                 avatar_images=["https://www.freeiconspng.com/thumbs/person-icon-blue/person-icon-blue-25.png","https://cdn-icons-png.flaticon.com/512/8943/8943377.png"],
+             )
+             with gr.Row():
+                 textbox = gr.Textbox(label="Type your message", scale=10)
+                 clear = gr.Button(value="New Chat", size="sm", scale=1)
+     clear.click(new_chat,[],[textbox, chatbot,references,file_down1,file_down2])
+     textbox.submit(chat, [chatbot, textbox], [textbox, chatbot,references,file_down1,file_down2])
+
+
+     files.upload(file_upload,[files,chatbot],[title,chatbot])
+
+
+ demo.launch(share=True)
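
For reference, the retrieval core of the added app.py can be exercised outside the Gradio UI. A minimal sketch, not part of the commit, assuming the same ./chroma_db path and "quickstart" collection the app creates, an OPENAI_API_KEY in the environment, and a placeholder question:

# Minimal sketch (not part of the commit): reopen the persisted Chroma
# collection that app.py builds and query it directly.
import chromadb
from llama_index.core import VectorStoreIndex, Settings
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.vector_stores.chroma import ChromaVectorStore

# Same models the app configures globally.
Settings.llm = OpenAI(model="gpt-3.5-turbo")
Settings.embed_model = OpenAIEmbedding(model="text-embedding-ada-002")

# Reopen the persisted collection and wrap it as a llama-index vector store.
db = chromadb.PersistentClient(path="./chroma_db")
store = ChromaVectorStore(chroma_collection=db.get_or_create_collection("quickstart"))
index = VectorStoreIndex.from_vector_store(store)

# as_chat_engine mirrors the app's query_engine; response.source_nodes carries
# the title/page_label/file_name metadata that chat() formats into "References".
engine = index.as_chat_engine(similarity_top_k=2)
response = engine.chat("What documents are indexed?")  # placeholder question
print(response.response)
for node in response.source_nodes:
    print(node.metadata.get("title"), node.metadata.get("file_name"))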