Spaces:
Sleeping
Sleeping
File size: 6,560 Bytes
2faf743 43b366f 2faf743 960ede5 43b366f 960ede5 2144d43 2faf743 fe5b216 2faf743 fe5b216 2faf743 917fd95 2faf743 65b9ed7 2faf743 917fd95 2faf743 4f60d0b 2faf743 7350606 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
import time
import gradio as gr
import logging
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.docstore.document import Document
import whisper_app
import llm_ops
# Audio file extensions accepted by the upload widget. Gradio's ``file_types``
# treats entries starting with "." as extensions and everything else as a
# media-type name, so the leading dot is required for the filter to work.
FILE_EXT = ['.wav', '.mp3']
# Upper bound of the "Max new tokens" slider.
MAX_NEW_TOKENS = 4096
# Initial position of the "Max new tokens" slider.
DEFAULT_MAX_NEW_TOKENS = 1024
# Initial position of the "Temperature" slider.
DEFAULT_TEMPERATURE = 0.1
def create_logger():
    """Return the application's console logger, configuring it on first use.

    The logger is named ``"APT_Realignment"``, emits records of level INFO
    and above through a single ``StreamHandler`` and does not propagate to
    the root logger.

    Returns:
        logging.Logger: the shared, configured logger instance.
    """
    logger = logging.getLogger("APT_Realignment")
    logger.setLevel(logging.INFO)
    # Check this logger's *own* handler list. ``hasHandlers()`` also looks at
    # ancestor loggers, so a handler on the root logger would stop ours from
    # being attached; with ``propagate = False`` INFO records would then be
    # silently dropped.
    if not logger.handlers:
        formatter = logging.Formatter('%(asctime)s:%(levelname)s:- %(message)s')
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)
    # Keep records out of the root logger to avoid duplicate output.
    logger.propagate = False
    return logger
def create_prompt():
    """Build the prompt template used by the retrieval-QA chain.

    The template has two placeholders, ``{context}`` and ``{question}``,
    and tells the model to say "I don't know" when it cannot answer.

    Returns:
        PromptTemplate: template with input variables ``context`` and
        ``question``.
    """
    # The template body is kept flush-left so no indentation leaks into the
    # text that is sent to the model.
    prompt_template = """Answer the questions regarding the content in the Audio .
Use the following context to answer.
If you don't know the answer, just say I don't know.
{context}
Question: {question}
Answer :"""
    prompt = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )
    return prompt
# Module-wide logger used by the functions below; created once at import time.
logger = create_logger()
def process_documents(documents,data_chunk=1500,chunk_overlap=100):
    """Split *documents* into overlapping chunks on newline boundaries.

    Args:
        documents: documents to split; passed to
            ``CharacterTextSplitter.split_documents``.
        data_chunk: chunk size handed to the splitter.
        chunk_overlap: overlap between consecutive chunks.

    Returns:
        The chunked documents produced by the splitter.
    """
    splitter = CharacterTextSplitter(
        separator='\n',
        chunk_size=data_chunk,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_documents(documents)
def audio_processor(wav_file,API_key,wav_model='small',llm='HuggingFace',temperature=0.1,max_tokens=4096):
    """Transcribe an uploaded audio file and build the retrieval-QA chain.

    The audio is transcribed with Whisper, the transcript is chunked and
    embedded into a FAISS index, and a ``RetrievalQA`` chain is built on top
    of it. The index and the chain are stored in the module globals
    ``vector_db`` and ``qa``, which ``infer`` reads.

    Args:
        wav_file: uploaded file object exposing the on-disk path as ``.name``,
            or ``None`` when nothing has been uploaded.
        API_key: key forwarded to the selected LLM backend.
        wav_model: Whisper model size name.
        llm: ``'HuggingFace'`` selects the Hugging Face model; any other
            value selects the OpenAI chat model.
        temperature: sampling temperature (Hugging Face backend only).
        max_tokens: generation limit (Hugging Face backend only).

    Returns:
        str: a status message for the UI.
    """
    global vector_db, qa

    # Without an upload there is nothing to transcribe; report it in the
    # status box instead of failing on ``None.name``.
    if wav_file is None:
        return "Please upload an audio file first."

    device = 'cpu'
    embedding_name = 'thenlper/gte-large'
    audio_path = wav_file.name

    # logging interpolates extra arguments into the message, so every
    # argument needs a matching %s placeholder.
    logger.info("Audio File Name : %s", audio_path)
    whisper = whisper_app.WHISPERModel(model_name=wav_model, device=device)
    logger.info("Whisper Model Loaded || Model size:%s", wav_model)

    # speech_to_text is expected to return a mapping with at least the
    # 'text', 'duration' and 'language' keys read below.
    text_info = whisper.speech_to_text(audio_path=audio_path)
    metadata = {
        # Record the file path rather than the repr of the file object.
        "source": audio_path,
        "duration": text_info['duration'],
        "language": text_info['language'],
    }
    document = [Document(page_content=text_info['text'], metadata=metadata)]
    logger.info("Document %s", document)

    # Use the application logger: the root logger's default level would
    # drop an INFO record sent through ``logging.info``.
    logger.info("Loading General Text Embeddings (GTE) model %s", embedding_name)
    embedding_model = SentenceTransformerEmbeddings(
        model_name=embedding_name, model_kwargs={"device": device}
    )
    texts = process_documents(documents=document)
    index = FAISS.from_documents(documents=texts, embedding=embedding_model)

    if llm == 'HuggingFace':
        chat = llm_ops.get_hugging_face_model(
            model_id="meta-llama/Llama-2-7b",
            API_key=API_key,
            temperature=temperature,
            max_tokens=max_tokens
        )
    else:
        chat = llm_ops.get_openai_chat_model(API_key=API_key)

    chain_type_kwargs = {"prompt": create_prompt()}
    chain = RetrievalQA.from_chain_type(
        llm=chat,
        chain_type='stuff',
        retriever=index.as_retriever(),
        chain_type_kwargs=chain_type_kwargs,
        return_source_documents=True
    )

    # Publish both globals together, only after everything succeeded, so a
    # failure part-way never pairs a new index with an old chain.
    vector_db = index
    qa = chain
    return "Audio Processing completed ..."
def infer(question, history):
    """Answer *question* using the retrieval-QA chain built from the audio.

    Args:
        question: the user's question text.
        history: chat history; accepted for interface compatibility but not
            used.

    Returns:
        str: the chain's answer, or a hint to process an audio file first
        when no chain has been built yet.
    """
    # ``qa`` and ``vector_db`` only exist as module globals after
    # ``audio_processor`` has run; look them up defensively so an early
    # question gets a readable reply instead of a NameError.
    chain = globals().get("qa")
    store = globals().get("vector_db")
    if chain is None or store is None:
        return "Please upload and process an audio file before asking a question."

    result = chain({"query": question})
    # The similarity scores are logged for diagnostics only; they do not
    # affect the returned answer.
    matching_docs_score = store.similarity_search_with_score(question)
    logger.info("Matching Score : %s", matching_docs_score)
    return result["result"]
def bot(history):
    """Stream the answer to the latest question into the chat history.

    The question is the first element of the last history entry. The answer
    is written into the second element one character at a time, and the
    whole history is yielded after each character with a 0.05 s pause.

    Args:
        history: chat history as a list of mutable ``[question, answer]``
            pairs.

    Yields:
        The same history object with the last answer progressively filled in.
    """
    latest = history[-1]
    answer = infer(latest[0], history)
    latest[1] = ""
    for ch in answer:
        latest[1] = latest[1] + ch
        time.sleep(0.05)
        yield history
def add_text(history, text):
    """Append the user's message to the chat history and clear the input box.

    Args:
        history: current chat history (list of ``[user, bot]`` pairs), or
            ``None`` when the chat is empty.
        text: the message the user just submitted.

    Returns:
        tuple: ``(new_history, "")``. ``new_history`` is a new list ending
        with ``[text, None]``; the empty string resets the question textbox.
    """
    # Add a mutable pair: ``bot`` assigns into ``history[-1][1]``, which
    # would fail on a tuple. ``history or []`` tolerates an empty/None chat.
    history = list(history or []) + [[text, None]]
    return history, ""
def loading_file():
    """Return the placeholder text written to the status box on upload."""
    status = "Loading..."
    return status
# Custom CSS passed to gr.Blocks: caps the width of the element with id
# "col-container" and centres it horizontally.
css="""
#col-container {max-width: 2048px; margin-left: auto; margin-right: auto;}
"""
# HTML header rendered at the top of the page through gr.HTML.
title = """
<div style="text-align: center;max-width: 2048px;">
<h1>Q&A using LLAMA on Audio files</h1>
<p style="text-align: center;">Upload a Audio file/link and query LLAMA-chatbot.
<i> Tools uses State of the Art Models from HuggingFace/OpenAI so, make sure to add your key.</i>
</p>
</div>
"""
# Build the Gradio interface: model/key selection, the chat area, the audio
# upload with generation options, and the event wiring between them.
# NOTE(review): the container nesting below is a best-effort layout; confirm
# it matches the intended page arrangement.
with gr.Blocks(css=css) as demo:
    with gr.Row():
        with gr.Column(elem_id="col-container"):
            gr.HTML(title)

    with gr.Column():
        with gr.Row():
            LLM_option = gr.Dropdown(['HuggingFace','OpenAI'],label='Select HuggingFace/OpenAI')
            API_key = gr.Textbox(label="Add API key", type="password",autofocus=True)
            wav_model = gr.Dropdown(['small','medium','large'],label='Select Whisper model')

    with gr.Group():
        chatbot = gr.Chatbot(height=270)

    with gr.Row():
        question = gr.Textbox(label="Type your question !",lines=1,interactive=True)

    with gr.Row():
        submit_btn = gr.Button(value="Send message", variant="primary", scale = 1)
        clean_chat_btn = gr.Button("Delete Chat")

    with gr.Column():
        with gr.Box():
            audio_file = gr.File(label="Upload Audio File ", file_types=FILE_EXT, type="file")
            with gr.Accordion(label='Advanced options', open=False):
                max_new_tokens = gr.Slider(
                    label='Max new tokens',
                    # The minimum must not exceed the default value, otherwise
                    # the slider starts outside its own range.
                    minimum=1,
                    maximum=MAX_NEW_TOKENS,
                    step=1,
                    value=DEFAULT_MAX_NEW_TOKENS,
                )
                temperature = gr.Slider(
                    label='Temperature',
                    minimum=0.1,
                    maximum=4.0,
                    step=0.1,
                    value=DEFAULT_TEMPERATURE,
                )
        with gr.Row():
            langchain_status = gr.Textbox(label="Status", placeholder="", interactive = False)
            load_audio = gr.Button("Upload Audio File")

    # A component object is always truthy while the UI is being built, so the
    # handlers are registered unconditionally; a missing upload has to be
    # handled when the click actually fires.
    load_audio.click(loading_file, None, langchain_status, queue=False)
    load_audio.click(audio_processor, inputs=[audio_file,API_key,wav_model,LLM_option,temperature,max_new_tokens], outputs=[langchain_status], queue=False)

    # Chat wiring: first append the question and clear the textbox, then
    # stream the answer into the chatbot.
    question.submit(add_text, [chatbot, question], [chatbot, question]).then(
        bot, chatbot, chatbot
    )
    submit_btn.click(add_text, [chatbot, question], [chatbot, question]).then(
        bot, chatbot, chatbot
    )
    # Returning None resets the chatbot to an empty conversation.
    clean_chat_btn.click(lambda: None, None, chatbot, queue=False)

# ``bot`` is a generator, which Gradio only streams when the queue is enabled.
demo.queue()
demo.launch()