SabariKameswaran's picture
Update app.py
4b02ce6 verified
raw
history blame contribute delete
No virus
2.15 kB
import os

from flask import Flask, jsonify, request, send_file, url_for
from gtts import gTTS
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import TextLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.indexes import VectorstoreIndexCreator
from langchain.indexes.vectorstore import VectorStoreIndexWrapper
from langchain.memory import ConversationBufferMemory
from langchain.vectorstores import Chroma
app = Flask(__name__)

# Only fall back to the placeholder when no key is already configured; an
# unconditional assignment would overwrite a real key supplied through the
# environment. Never commit a real key to this file.
os.environ.setdefault("OPENAI_API_KEY", "YOUR_API_KEY")

# When True, the Chroma index is saved to / reopened from ./persist; when
# False, it is rebuilt from new.txt on every main_func call.
PERSIST = True

# Most recent question passed to main_func (updated via `global query`).
query = None
def _load_index():
    """Open or build the vector-store index over ``new.txt``.

    If PERSIST is true and a ``persist`` directory exists, the Chroma store
    saved there is reopened. Otherwise the index is built from ``new.txt``
    and, when PERSIST is true, written to ``persist``.
    """
    if PERSIST and os.path.exists("persist"):
        print("Reusing index...\n")
        vectorstore = Chroma(persist_directory="persist", embedding_function=OpenAIEmbeddings())
        return VectorStoreIndexWrapper(vectorstore=vectorstore)
    loader = TextLoader("new.txt")
    if PERSIST:
        creator = VectorstoreIndexCreator(vectorstore_kwargs={"persist_directory": "persist"})
    else:
        creator = VectorstoreIndexCreator()
    index = creator.from_loaders([loader])
    print(index)
    return index


def main_func(message, history):
    """Answer *message* with retrieval-augmented generation over new.txt.

    Args:
        message: The user's question.
        history: Prior turns as ``(question, answer)`` pairs; 2-item lists
            are also accepted. ``None`` is treated as an empty history.
            The new turn is appended to this list in place.

    Returns:
        The answer string produced by the chain.
    """
    global query  # module-level record of the latest question
    chat_history = [] if history is None else history
    index = _load_index()
    # History is passed explicitly rather than via a ConversationBufferMemory.
    # A chain memory's variables override the "chat_history" input, and a
    # memory created fresh on every call is always empty, so `history` was
    # silently ignored.
    chain = ConversationalRetrievalChain.from_llm(
        llm=ChatOpenAI(),
        retriever=index.vectorstore.as_retriever(),
        verbose=True,
    )
    # The chain's default history formatter accepts tuples or message
    # objects, not lists, so convert list-shaped turns to tuples.
    turns = [tuple(turn) if isinstance(turn, list) else turn for turn in chat_history]
    query = message
    result = chain({"question": query, "chat_history": turns})
    answer = result["answer"]
    print(answer)
    chat_history.append((query, answer))
    return answer
@app.route('/generate-text/<input_text>', methods=['POST'])
def generate_text(input_text):
    """Answer *input_text* via main_func and convert the answer to speech.

    Returns a JSON object with the answer text (``generated_text``) and an
    absolute URL of the ``/audio`` endpoint that serves the MP3
    (``audio_url``).
    """
    generated_text = main_func(input_text, [])
    tts = gTTS(text=generated_text, lang='en')
    # Flask's send_file joins a relative path onto app.root_path, whereas
    # gTTS.save resolves it against the current working directory. Save
    # under app.root_path so /audio finds the file regardless of the CWD.
    # NOTE(review): every request overwrites the same file, so concurrent
    # requests can serve each other's audio.
    tts.save(os.path.join(app.root_path, "output.mp3"))
    return jsonify({
        'generated_text': generated_text,
        # url_for keeps any script-root prefix the app is mounted under,
        # which request.host_url + 'audio' dropped.
        'audio_url': url_for('get_audio', _external=True),
    })
@app.route('/audio')
def get_audio():
    """Serve the most recently generated MP3 as a download.

    Responds with a 404 JSON error when no audio has been generated yet,
    instead of letting send_file raise and produce a 500.
    """
    # Same location a relative "output.mp3" resolves to in send_file,
    # spelled out explicitly.
    audio_path = os.path.join(app.root_path, "output.mp3")
    if not os.path.isfile(audio_path):
        return jsonify({'error': 'No audio has been generated yet'}), 404
    return send_file(audio_path, as_attachment=True)
if __name__ == "__main__":
app.run(debug=True)