# -*-coding:utf-8 -*-
import json
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI

from prompt import (
    QA_gen_sys_prompt,
    QA_gen_user_prompt,
    QA_answer_sys_prompt,
    QA_answer_user_prompt,
)


def get_qa(text, openai_key):
    """Generate question-answer pairs from `text` using the QA-generation prompt templates."""
    llm = ChatOpenAI(openai_api_key=openai_key, max_tokens=2000, temperature=0.8)
    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(QA_gen_sys_prompt),
            HumanMessagePromptTemplate.from_template(QA_gen_user_prompt),
        ]
    )
    chain = LLMChain(llm=llm, prompt=prompt)
    print('Generating questions from template')
    qa = chain({'text': text})
    # The generation prompt asks the model for JSON output; parse the raw completion.
    result = json.loads(qa['text'])
    return result


def get_answer(context, question, openai_key):
    """Answer `question` against `context` using the QA-answering prompt templates."""
    llm = ChatOpenAI(openai_api_key=openai_key, max_tokens=2000, temperature=0.8)
    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(QA_answer_sys_prompt),
            HumanMessagePromptTemplate.from_template(QA_answer_user_prompt),
        ]
    )
    chain = LLMChain(llm=llm, prompt=prompt)
    print('Generating answer from template')
    answer = chain({'text': context, 'question': question})
    answer = answer['text']
    return answer
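

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module. Assumptions:
    # - 'sample.txt' and the API key below are placeholders to be replaced.
    # - get_qa returns a parsed JSON list of items carrying a 'question' field,
    #   as requested by the templates in prompt.py; adjust if the prompts
    #   specify a different structure.
    openai_key = 'YOUR_OPENAI_API_KEY'
    with open('sample.txt', encoding='utf-8') as f:
        source_text = f.read()

    qa_items = get_qa(source_text, openai_key)
    for item in qa_items:
        answer = get_answer(source_text, item['question'], openai_key)
        print(item['question'], '->', answer)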