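"""Gradio chat app that answers questions about an uploaded Jupyter notebook.

The notebook is converted to HTML, split into cells, and the cells (plus past
chat turns) most similar to the user's message are retrieved by
sentence-transformer cosine similarity and prepended to the prompt sent to
microsoft/Phi-3-mini-4k-instruct via the Hugging Face Inference API.
"""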
import os
import heapq

import requests
import gradio as gr

from utils import package_installer

# package_installer (defined in this Space's utils module) installs the named
# package at runtime; nbformat and beautifulsoup4 already come in as
# nbconvert dependencies, so they need no separate install.
package_installer('sentence_transformers')
package_installer('nbconvert')
package_installer('inflect')

from sentence_transformers import SentenceTransformer, util
import nbconvert
import nbformat
from bs4 import BeautifulSoup
from inflect import engine

inflect_engine = engine()
def convert_ipynb_to_html(input_file):
    """Convert a .ipynb file to an HTML string."""
    # Load the .ipynb file into a nbformat.NotebookNode object
    notebook = nbformat.read(input_file, as_version=4)
    # Render it with nbconvert's HTML exporter
    html_exporter = nbconvert.HTMLExporter()
    body, _resources = html_exporter.from_notebook_node(notebook)
    return body
API_TOKEN = os.environ['HF_TOKEN']
API_URL = "https://api-inference.huggingface.co/models/microsoft/Phi-3-mini-4k-instruct"
headers = {"Authorization": f"Bearer {API_TOKEN}"}
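# all-MiniLM-L6-v2 is a small sentence-embedding model, used below to rank
# notebook cells and past chat turns by cosine similarity to the user message.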
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
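# Phi-3's chat template delimits turns with <|user|>, <|assistant|> and <|end|>.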
def user_prompt_template(user_msg: str):
    return f"<|user|>\n{user_msg}<|end|>\n<|assistant|>"

def assistant_response_template(assistant_msg: str):
    return f"{assistant_msg}<|end|>\n"
def query(payload):
    # Query the hosted Inference API: on success it returns a list of
    # {'generated_text': ...} dicts, on failure a dict describing the error.
    response = requests.post(API_URL, headers=headers, json=payload, timeout=120)
    return response.json()
def chat(message, history):
    """Answer a user message, optionally grounded in an uploaded notebook."""
    formatted_user_msg = user_prompt_template(message['text'])
    # Drop file-upload turns (item[0] is a file tuple, the reply is None) so
    # only real text exchanges are considered for chat-history retrieval.
    hist_cop = [item for item in history
                if not (None in item and isinstance(item[0], tuple))]
    # Cell texts prefixed with index comments, plus the index comments alone;
    # both stay empty when no notebook is attached to the message.
    indexed_cells_list = []
    index_comments = []
    try:
        html_code = convert_ipynb_to_html(message['files'][0])
        soup = BeautifulSoup(html_code, 'html.parser')
        text = soup.get_text()
        # Blank lines separate cells in the extracted text: gather the
        # non-blank lines between them into one string per cell.
        string = ''
        cells_list = []
        for item in text.split('\n'):
            if len(item) > 0:
                string += item + '\n'
            elif len(string) > 0:
                cells_list.append(string)
                string = ''
        # Drop nbconvert artifacts that are not real cells.
        cells_list = [item for item in cells_list
                      if item not in ('Notebook\n', 'In\xa0[\xa0]:\n')]
        for i, itxt in enumerate(cells_list):
            # Label each cell several ways ("1st cell", "Cell Number: 1",
            # "Cell Number: one") so differently phrased references in the
            # user message still match. inflect's ordinal() also handles the
            # 11th/12th/13th and 21st/22nd edge cases correctly.
            cell_labels = (f'# {inflect_engine.ordinal(i+1)} cell\n'
                           f'# Cell Number: {i+1}\n'
                           f'# Cell Number: {inflect_engine.number_to_words(i+1)}\n')
            indexed_cells_list.append(cell_labels + itxt)
            index_comments.append(cell_labels)
    except Exception:
        # No file attached (message['files'] is empty) or the upload could not
        # be parsed as a notebook; answer without notebook context.
        pass
    emb_formatted_user_msg = embedding_model.encode(formatted_user_msg, convert_to_tensor=True)
    # Retrieve up to five cells whose labels are most similar to the user
    # message, keeping them in notebook order.
    top_5_cells = []
    if index_comments:
        emb_cells = embedding_model.encode(index_comments, convert_to_tensor=True)
        cell_scores = util.cos_sim(emb_formatted_user_msg, emb_cells)[0].tolist()
        top_indices = heapq.nlargest(5, range(len(cell_scores)), key=lambda i: cell_scores[i])
        top_5_cells = [indexed_cells_list[i] for i in sorted(top_indices)]
    # Likewise retrieve up to two past exchanges most similar to the user message.
    top_2_chats = []
    if hist_cop:
        chat_history = [user_prompt_template(item[0]) + assistant_response_template(item[1])
                        for item in hist_cop]
        emb_chat_history = embedding_model.encode(chat_history, convert_to_tensor=True)
        chat_scores = util.cos_sim(emb_formatted_user_msg, emb_chat_history)[0].tolist()
        top_indices = heapq.nlargest(2, range(len(chat_scores)), key=lambda i: chat_scores[i])
        top_2_chats = [chat_history[i] for i in sorted(top_indices)]
    similar_chat_history = ''.join(top_2_chats)
    # Prepend the retrieved cells to the user message, and prefix the most
    # similar past exchanges so the model sees them as prior turns.
    top_5_cells_string = '\n'.join(top_5_cells)
    context_plus_message = top_5_cells_string + message['text']
    formatted_context_plus_message = user_prompt_template(context_plus_message)
    user_input = similar_chat_history + formatted_context_plus_message
    inp_dict = {"inputs": user_input,
                "parameters": {"max_new_tokens": 750, "temperature": 0.01}}
    output = query(inp_dict)
    try:
        # The API echoes the prompt; strip it and the trailing <|end|> token.
        output_text = output[0]['generated_text']
        formatted_assistant_msg = output_text.replace(user_input, '').strip().removesuffix('<|end|>')
    except Exception:
        if isinstance(output, dict):
            formatted_assistant_msg = f"An error has occurred; the API returned a dict with items: {output.items()}"
        else:
            formatted_assistant_msg = f"An error has occurred; the API returned a {type(output)} of length {len(output)}"
    # Debug logging to the Space console.
    print(user_input)
    print()
    print(indexed_cells_list)
    return formatted_assistant_msg
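# multimodal=True delivers each user turn to chat() as a dict with 'text' and
# 'files' keys, matching the accesses above.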
demo = gr.ChatInterface(chat, multimodal=True)
if __name__ == '__main__':
    demo.launch()
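# To run locally: `python app.py`, with an HF_TOKEN environment variable
# holding a Hugging Face API token (read into API_TOKEN above).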