File size: 2,733 Bytes
8f835ef
 
90f787c
8f835ef
 
e2275e6
cacb7f3
47f7c8e
ac02bee
 
 
 
 
 
8f835ef
 
 
23fd6a8
8f835ef
 
 
cacb7f3
8f835ef
 
6d50512
23fd6a8
6d50512
23fd6a8
 
ac02bee
 
 
 
 
23fd6a8
 
ac02bee
6d50512
 
ac02bee
6d50512
 
8f835ef
 
ac02bee
6d50512
23fd6a8
6d50512
ac02bee
6d50512
ac02bee
23fd6a8
 
ac02bee
 
 
6d50512
ac02bee
6d50512
23fd6a8
cacb7f3
8f835ef
23fd6a8
ac02bee
8f835ef
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import time
import gradio as gr
import os
import asyncio
from pymongo import MongoClient
from langchain_community.vectorstores import MongoDBAtlasVectorSearch
from langchain_openai import OpenAIEmbeddings
from langchain_community.llms import OpenAI
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

output_parser = StrOutputParser()

import json


## Connect to MongoDB Atlas and build the RAG pipeline.
# Everything that can fail on bad configuration is inside the try below so a
# misconfigured deployment still starts; get_movies then shows its
# setup-help message instead of the app crashing at import time.
MONGODB_ATLAS_CLUSTER_URI = os.getenv('MONGODB_ATLAS_CLUSTER_URI')
db_name = 'sample_mflix'
collection_name = 'embedded_movies'

# Pre-bind so these names exist even if construction fails part-way.
client = None
collection = None

try:
    client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
    collection = client[db_name][collection_name]

    ## Vector store init: documents' text comes from the 'plot' field and the
    ## stored vectors from 'plot_embedding', searched via the 'vector_index'
    ## Atlas index.
    vector_store = MongoDBAtlasVectorSearch(
        embedding=OpenAIEmbeddings(),
        collection=collection,
        index_name='vector_index',
        text_key='plot',
        embedding_key='plot_embedding',
    )

    ## LLM init
    llm = ChatOpenAI(temperature=0)
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a movie recommendation engine which post a concise and short summary on relevant movies."),
        ("user", "List of movies: {input}"),
    ])

    ## RAG chain: fill the prompt -> chat model -> plain string output
    chain = prompt | llm | output_parser

except Exception as err:
    # The original author expected this to fire when the OpenAI key is wrong;
    # it can also come from MongoClient (e.g. a malformed URI), so the real
    # cause is printed alongside the legacy message. Catching Exception (not
    # a bare except) lets KeyboardInterrupt/SystemExit propagate.
    print('Open AI key is wrong')
    print(f'Initialisation failed: {err!r}')
    vector_store = None
    chain = None

# Shown in the chat whenever retrieval or generation cannot run, which in
# practice means missing/incorrect deployment secrets.
_SETUP_HELP_MESSAGE = "Please clone the space and add your open ai key as well as your MongoDB Atlas URI in the Secret Section of your Space\n OPENAI_API_KEY (your Open AI key) and MONGODB_ATLAS_CLUSTER_URI (0.0.0.0/0 whitelisted instance with Vector index created) \n\n For more information : https://mongodb.com/products/platform/atlas-vector-search"


def get_movies(message, history):
    """Stream an LLM-written summary of the movies most similar to *message*.

    Gradio ``ChatInterface`` callback. It retrieves the top 3 documents from
    the module-level ``vector_store`` and asks the module-level ``chain`` to
    summarise them. It then yields progressively longer prefixes of the
    summary so the reply appears to be typed out.

    Args:
        message: The user's chat message, used as the similarity-search query.
        history: Previous chat turns supplied by Gradio; unused here.

    Yields:
        str: ``"Found: \\n\\n"`` followed by a growing prefix of the summary,
        or a single setup-help message if retrieval/generation fails.
    """
    # Initialisation failed at import time: nothing to search.
    if vector_store is None:
        yield _SETUP_HELP_MESSAGE
        return

    # Only the remote calls are guarded. The yields below stay outside the
    # try so that GeneratorExit (raised when Gradio closes the stream) is
    # never swallowed and turned into "generator ignored GeneratorExit".
    try:
        movies = vector_store.similarity_search(message, k=3)
        context = ''.join(
            f"Title : {movie.metadata.get('title', 'Unknown')}"
            f"\n------------\nPlot: {movie.page_content}\n\n"
            for movie in movies
        )
        ## Invoke RAG on the located documents
        summary = chain.invoke({"input": context})
    except Exception as err:
        print(f'get_movies failed: {err!r}')
        yield _SETUP_HELP_MESSAGE
        return

    header = "Found: " + "\n\n"
    # Always yield at least once so an empty LLM reply still produces a message.
    if not summary:
        yield header
        return

    for end in range(1, len(summary) + 1):
        time.sleep(0.05)  # cosmetic typing effect
        yield header + summary[:end]
    

## Start gradio chat interface
# .queue() is kept from the original; presumably it is needed because
# get_movies is a streaming generator (older Gradio versions required the
# queue for generator callbacks) — confirm against the pinned Gradio version.
demo = gr.ChatInterface(
    get_movies,
    examples=["What movies are scary?", "Find me a comedy", "Movies for kids"],
    title="Movies Atlas Vector Search",
    description=(
        "This small chat uses a similarity search to find relevant movies. "
        "It uses MongoDB Atlas Vector Search; read more here: "
        "https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-tutorial"
    ),
    submit_btn="Search",
).queue()

if __name__ == "__main__":
    demo.launch()