Reformat
Browse files- app.py +7 -3
- main.py +2 -0
- mapping.py +1 -1
- memory.py +7 -6
- utils.py +7 -6
app.py
CHANGED
@@ -1,10 +1,12 @@
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
-
from main import index, run
|
3 |
from gtts import gTTS
|
4 |
-
import os, time, uuid
|
5 |
-
|
6 |
from transformers import pipeline
|
7 |
|
|
|
|
|
8 |
p = pipeline("automatic-speech-recognition", model="openai/whisper-base")
|
9 |
|
10 |
"""Use text to call chat method from main.py"""
|
@@ -14,9 +16,11 @@ models = ["GPT-3.5", "Flan UL2", "Flan T5"]
|
|
14 |
with gr.Blocks(theme='snehilsanyal/scikit-learn') as demo:
|
15 |
state = gr.State([])
|
16 |
|
|
|
17 |
def create_session_id():
|
18 |
return str(uuid.uuid4())
|
19 |
|
|
|
20 |
def add_text(history, text, model):
|
21 |
print("Question asked: " + text)
|
22 |
response = run_model(text, model)
|
|
|
1 |
+
import time
|
2 |
+
import uuid
|
3 |
+
|
4 |
import gradio as gr
|
|
|
5 |
from gtts import gTTS
|
|
|
|
|
6 |
from transformers import pipeline
|
7 |
|
8 |
+
from main import index, run
|
9 |
+
|
10 |
p = pipeline("automatic-speech-recognition", model="openai/whisper-base")
|
11 |
|
12 |
"""Use text to call chat method from main.py"""
|
|
|
16 |
with gr.Blocks(theme='snehilsanyal/scikit-learn') as demo:
|
17 |
state = gr.State([])
|
18 |
|
19 |
+
|
20 |
def create_session_id():
|
21 |
return str(uuid.uuid4())
|
22 |
|
23 |
+
|
24 |
def add_text(history, text, model):
|
25 |
print("Question asked: " + text)
|
26 |
response = run_model(text, model)
|
main.py
CHANGED
@@ -1,11 +1,13 @@
|
|
1 |
from utils import get_search_index, generate_answer, set_model_and_embeddings, set_session_id
|
2 |
|
|
|
3 |
def index(model, session_id):
|
4 |
set_session_id(session_id)
|
5 |
set_model_and_embeddings(model)
|
6 |
get_search_index(model)
|
7 |
return True
|
8 |
|
|
|
9 |
def run(question, model, session_id):
|
10 |
index(model, session_id)
|
11 |
return generate_answer(question)
|
|
|
1 |
from utils import get_search_index, generate_answer, set_model_and_embeddings, set_session_id
|
2 |
|
3 |
+
|
4 |
def index(model, session_id):
|
5 |
set_session_id(session_id)
|
6 |
set_model_and_embeddings(model)
|
7 |
get_search_index(model)
|
8 |
return True
|
9 |
|
10 |
+
|
11 |
def run(question, model, session_id):
|
12 |
index(model, session_id)
|
13 |
return generate_answer(question)
|
mapping.py
CHANGED
@@ -116,4 +116,4 @@ FILE_URL_MAPPING = {
|
|
116 |
'https://www.coursera.org/learn/3d-printing-revolution/supplement/HZXB5/module-1-overview',
|
117 |
|
118 |
'docs/02_module-1-what-is-3d-printing/02_3d-printing-insights/07_what-would-you-make-exercise_peer_assignment_instructions.html':
|
119 |
-
'https://www.coursera.org/learn/3d-printing-revolution/peer/t8bqq/what-would-you-make-exercise'}
|
|
|
116 |
'https://www.coursera.org/learn/3d-printing-revolution/supplement/HZXB5/module-1-overview',
|
117 |
|
118 |
'docs/02_module-1-what-is-3d-printing/02_3d-printing-insights/07_what-would-you-make-exercise_peer_assignment_instructions.html':
|
119 |
+
'https://www.coursera.org/learn/3d-printing-revolution/peer/t8bqq/what-would-you-make-exercise'}
|
memory.py
CHANGED
@@ -1,13 +1,15 @@
|
|
1 |
import json
|
2 |
-
from datetime import datetime
|
3 |
-
from pymongo import errors
|
4 |
-
from langchain.schema import AIMessage, BaseMessage, HumanMessage, messages_from_dict, _message_to_dict
|
5 |
-
from langchain.memory import MongoDBChatMessageHistory
|
6 |
import logging
|
|
|
7 |
from typing import List
|
8 |
|
|
|
|
|
|
|
9 |
|
10 |
logger = logging.getLogger(__name__)
|
|
|
|
|
11 |
class CustomMongoDBChatMessageHistory(MongoDBChatMessageHistory):
|
12 |
|
13 |
@property
|
@@ -43,7 +45,6 @@ class CustomMongoDBChatMessageHistory(MongoDBChatMessageHistory):
|
|
43 |
# Determine the sender based on the message type
|
44 |
sender = "ai" if isinstance(message, AIMessage) else "human"
|
45 |
|
46 |
-
|
47 |
# Create the message object with the desired format
|
48 |
message_obj = {
|
49 |
"type": sender,
|
@@ -59,4 +60,4 @@ class CustomMongoDBChatMessageHistory(MongoDBChatMessageHistory):
|
|
59 |
upsert=True
|
60 |
)
|
61 |
except errors.WriteError as err:
|
62 |
-
logger.error(err)
|
|
|
1 |
import json
|
|
|
|
|
|
|
|
|
2 |
import logging
|
3 |
+
from datetime import datetime
|
4 |
from typing import List
|
5 |
|
6 |
+
from langchain.memory import MongoDBChatMessageHistory
|
7 |
+
from langchain.schema import AIMessage, BaseMessage, HumanMessage, messages_from_dict, _message_to_dict
|
8 |
+
from pymongo import errors
|
9 |
|
10 |
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
|
13 |
class CustomMongoDBChatMessageHistory(MongoDBChatMessageHistory):
|
14 |
|
15 |
@property
|
|
|
45 |
# Determine the sender based on the message type
|
46 |
sender = "ai" if isinstance(message, AIMessage) else "human"
|
47 |
|
|
|
48 |
# Create the message object with the desired format
|
49 |
message_obj = {
|
50 |
"type": sender,
|
|
|
60 |
upsert=True
|
61 |
)
|
62 |
except errors.WriteError as err:
|
63 |
+
logger.error(err)
|
utils.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1 |
import os
|
2 |
import pickle
|
3 |
-
import langchain
|
4 |
|
5 |
import faiss
|
|
|
6 |
from langchain import HuggingFaceHub
|
|
|
7 |
from langchain.chains import ConversationalRetrievalChain
|
8 |
from langchain.chat_models import ChatOpenAI
|
9 |
from langchain.document_loaders import DirectoryLoader, TextLoader, UnstructuredHTMLLoader
|
@@ -16,9 +17,9 @@ from langchain.prompts.chat import (
|
|
16 |
)
|
17 |
from langchain.text_splitter import CharacterTextSplitter
|
18 |
from langchain.vectorstores.faiss import FAISS
|
19 |
-
|
20 |
-
from memory import CustomMongoDBChatMessageHistory
|
21 |
from mapping import FILE_URL_MAPPING
|
|
|
22 |
|
23 |
langchain.llm_cache = InMemoryCache()
|
24 |
|
@@ -67,7 +68,8 @@ def set_session_id(session_id):
|
|
67 |
connection_string=MONGO_DB_URL, session_id=session_id, database_name='coursera_bots',
|
68 |
collection_name='printing_3d_revolution'
|
69 |
)
|
70 |
-
memory = ConversationBufferWindowMemory(memory_key="chat_history", chat_memory=message_history, k=10,
|
|
|
71 |
|
72 |
|
73 |
def set_model_and_embeddings(model):
|
@@ -86,7 +88,7 @@ def set_model(model):
|
|
86 |
llm = ChatOpenAI(model_name="gpt-4", temperature=0.1)
|
87 |
elif model == "Flan UL2":
|
88 |
print("Loading Flan-UL2")
|
89 |
-
llm = HuggingFaceHub(repo_id="google/flan-ul2", model_kwargs={"temperature": 0.1, "max_new_tokens":500})
|
90 |
elif model == "Flan T5":
|
91 |
print("Loading Flan T5")
|
92 |
llm = HuggingFaceHub(repo_id="google/flan-t5-base", model_kwargs={"temperature": 0.1})
|
@@ -230,7 +232,6 @@ def generate_answer(question) -> str:
|
|
230 |
|
231 |
|
232 |
def save_chat_history(question, result):
|
233 |
-
|
234 |
memory.chat_memory.add_user_message(question)
|
235 |
memory.chat_memory.add_ai_message(result["answer"])
|
236 |
print("chat history on saving: " + str(memory.chat_memory.messages))
|
|
|
1 |
import os
|
2 |
import pickle
|
|
|
3 |
|
4 |
import faiss
|
5 |
+
import langchain
|
6 |
from langchain import HuggingFaceHub
|
7 |
+
from langchain.cache import InMemoryCache
|
8 |
from langchain.chains import ConversationalRetrievalChain
|
9 |
from langchain.chat_models import ChatOpenAI
|
10 |
from langchain.document_loaders import DirectoryLoader, TextLoader, UnstructuredHTMLLoader
|
|
|
17 |
)
|
18 |
from langchain.text_splitter import CharacterTextSplitter
|
19 |
from langchain.vectorstores.faiss import FAISS
|
20 |
+
|
|
|
21 |
from mapping import FILE_URL_MAPPING
|
22 |
+
from memory import CustomMongoDBChatMessageHistory
|
23 |
|
24 |
langchain.llm_cache = InMemoryCache()
|
25 |
|
|
|
68 |
connection_string=MONGO_DB_URL, session_id=session_id, database_name='coursera_bots',
|
69 |
collection_name='printing_3d_revolution'
|
70 |
)
|
71 |
+
memory = ConversationBufferWindowMemory(memory_key="chat_history", chat_memory=message_history, k=10,
|
72 |
+
return_messages=True)
|
73 |
|
74 |
|
75 |
def set_model_and_embeddings(model):
|
|
|
88 |
llm = ChatOpenAI(model_name="gpt-4", temperature=0.1)
|
89 |
elif model == "Flan UL2":
|
90 |
print("Loading Flan-UL2")
|
91 |
+
llm = HuggingFaceHub(repo_id="google/flan-ul2", model_kwargs={"temperature": 0.1, "max_new_tokens": 500})
|
92 |
elif model == "Flan T5":
|
93 |
print("Loading Flan T5")
|
94 |
llm = HuggingFaceHub(repo_id="google/flan-t5-base", model_kwargs={"temperature": 0.1})
|
|
|
232 |
|
233 |
|
234 |
def save_chat_history(question, result):
|
|
|
235 |
memory.chat_memory.add_user_message(question)
|
236 |
memory.chat_memory.add_ai_message(result["answer"])
|
237 |
print("chat history on saving: " + str(memory.chat_memory.messages))
|