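"""Minimal Gradio chat app that streams responses from an NVIDIA-hosted
Llama 2 model through a LangChain prompt | LLM | output-parser chain."""
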
import getpass
import os

import gradio as gr
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_nvidia_ai_endpoints import ChatNVIDIA


## Read the NVIDIA API key from the environment; never hardcode secrets in source.
## If it is not already set, prompt for it interactively.
if not os.environ.get("NVIDIA_API_KEY", "").startswith("nvapi-"):
    os.environ["NVIDIA_API_KEY"] = getpass.getpass("Enter your NVIDIA API key (nvapi-...): ")
print(f"Retrieved NVIDIA_API_KEY beginning with \"{os.environ['NVIDIA_API_KEY'][:9]}...\"")


## Chat pipeline: prompt -> NVIDIA-hosted LLM -> plain-text output parser
inst_llm = ChatNVIDIA(model="nv_llama2_rlhf_70b")  ## Llama 2 70B RLHF model on NVIDIA AI endpoints

prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a friendly chatbot that helps people with their problems."),
    ("user", "Previous conversation: {previous}. Input: {input}")
])

chain = prompt | inst_llm | StrOutputParser()  ## expects "input" and "previous" keys
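## Optional sanity check: invoke the chain once outside the UI, e.g.
## print(chain.invoke({"input": "Hello!", "previous": []}))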

def chat_stream(message, history):
    ## Stream the response token by token, yielding the growing buffer so the
    ## Gradio UI updates as text arrives. gr.ChatInterface passes and tracks
    ## `history` itself, so there is no need to append to it manually.
    buffer = ""
    for token in chain.stream({"input": message, "previous": history}):
        buffer += token
        yield buffer


## Launch the chat UI; Gradio prints a local URL and, with share=True, a temporary public link.
gr.ChatInterface(chat_stream).queue().launch(debug=True, share=True)