File size: 2,586 Bytes
1f8ce60
 
 
293fd6c
14b2c99
1f8ce60
 
 
14b2c99
 
1f8ce60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14b2c99
293fd6c
 
14b2c99
 
 
 
 
 
1f8ce60
 
 
 
 
14b2c99
1f8ce60
 
14b2c99
 
 
 
 
 
 
 
 
 
 
1f8ce60
 
 
 
 
 
 
 
 
 
e570bec
 
 
1f8ce60
 
 
 
e570bec
1f8ce60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import json
import os
import re

import requests
from dotenv import load_dotenv
from openai import OpenAI

from .moderation import check_moderation_text
# Load variables from a .env file (python-dotenv) before the environment
# lookups further down in this module (HF_ACCESS_TOKEN,
# LLAMA2_INFERENCE_API_URL) are evaluated.
load_dotenv()

# Module-level OpenAI client shared by `query_openai`. It is built with no
# explicit arguments, so its configuration presumably comes from the
# environment -- TODO confirm which variables are required.
client = OpenAI()

def auto_suggest_normalize(text):
    """Yield each non-blank line of `text` with surrounding whitespace removed.

    This is a generator: callers iterate over it (or wrap it in `list(...)`).

    Args:
        text: The raw multi-line string to split on newlines. Any non-string
            value produces an empty iteration.

    Yields:
        str: Every line of `text` that is non-empty after stripping.
    """
    # The ask helpers in this module return False or None on failure; treat
    # such values as "no suggestions" instead of raising while iterating.
    if not isinstance(text, str):
        return

    for line in text.split('\n'):
        line = line.strip()
        if line:
            yield line

# Bearer-token Authorization header sent by `query`. The token is read from
# HF_ACCESS_TOKEN when this module is imported, so a missing variable raises
# KeyError at import time.
headers = {"Authorization": f'Bearer {os.environ["HF_ACCESS_TOKEN"]}'}
def query(payload, timeout=120):
    """POST `payload` as JSON to the LLAMA2 inference URL and return the decoded reply.

    The target URL is read from the LLAMA2_INFERENCE_API_URL environment
    variable on every call and the module-level `headers` are attached.

    Args:
        payload: JSON-serialisable request body.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout `requests` waits indefinitely on a stalled connection.

    Returns:
        The response body parsed as JSON.

    Raises:
        KeyError: If LLAMA2_INFERENCE_API_URL is not set.
        requests.exceptions.RequestException: On connection errors or timeout.
        ValueError: If the response body is not valid JSON.
    """
    response = requests.post(
        os.environ["LLAMA2_INFERENCE_API_URL"],
        headers=headers,
        json=payload,
        timeout=timeout,
    )
    return response.json()

def query_openai(prompt):
    """Send `prompt` as one user message to gpt-3.5-turbo and return the reply text.

    Args:
        prompt: Text forwarded as the content of a single "user" message.

    Returns:
        None when `check_moderation_text(prompt)` is truthy (the API is not
        called in that case); otherwise the first choice's message content
        with surrounding whitespace stripped.
    """
    flagged = check_moderation_text(prompt)
    if flagged:
        return None

    user_message = {"role": "user", "content": prompt}
    completion = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[user_message],
    )
    first_choice = completion.choices[0]
    return first_choice.message.content.strip()

def remove_extras(text):
    """Strip list numbering, backslashes and double quotes from `text`.

    Only a numeric marker at the start of a line ("1.", "4.", "10.", ...) is
    removed, so numbers inside the sentence (e.g. "GPT-3.5") are preserved.
    A leading decimal such as "3.5 stars" is not treated as a list marker.

    Args:
        text: A suggestion line, or several newline-separated lines.

    Returns:
        str: The cleaned text with surrounding whitespace stripped.
    """
    # Drop quoting/escaping characters first so a marker hidden behind a
    # leading quote is still found at the start of the line.
    cleaned = text.replace('\\', '').replace('"', '')
    # [ \t]* rather than \s* so the match never swallows a preceding newline.
    cleaned = re.sub(r'^[ \t]*\d+\.(?!\d)', '', cleaned, flags=re.MULTILINE)
    return cleaned.strip()
    
def auto_suggest_ask_llama2(prompt):
    """Forward `prompt` to `query` as the "inputs" field and return its result.

    The result is printed before being returned. Any exception raised along
    the way is printed and reported to the caller as False.
    """
    request_body = {"inputs": prompt}
    try:
        result = query(request_body)
        print(result)
    except Exception as err:
        print(err)
        return False
    return result
    
def auto_suggest_ask_gpt(prompt):
    """Return the result of `query_openai(prompt)`, or False if it raises.

    The result is printed (followed by the literal tag 'output') before
    being returned; a caught exception is printed instead.
    """
    try:
        reply = query_openai(prompt)
        print(reply, 'output')
    except Exception as err:
        print(err)
        return False
    return reply

def auto_suggest_normalize_llama2(text):
    """Collect the numbered lines ("1." to "5.") of `text`, cleaned by `remove_extras`.

    A line qualifies when, after stripping whitespace, it begins with one of
    the markers "1.", "2.", "3.", "4." or "5.". If an exception occurs while
    processing, it is printed and whatever was collected so far is returned.

    Returns:
        list: The cleaned qualifying lines, in order; empty when none match.
    """
    markers = ("1.", "2.", "3.", "4.", "5.")
    suggestions = []
    try:
        for raw_line in text.split('\n'):
            if raw_line.strip().startswith(markers):
                suggestions.append(remove_extras(raw_line))
    except Exception as err:
        print(err)

    return suggestions

def auto_suggest_ask(prompt):
    """Ask the local generation server for a completion of `prompt`.

    Posts to http://localhost:11434/api/generate (presumably an Ollama
    server -- confirm) requesting model "mistral:v0.2". The response body is
    read as newline-separated JSON objects; the "response" field of each
    object is concatenated until an object whose "done" field is truthy.

    Args:
        prompt: The prompt text to send.

    Returns:
        str: The concatenated generated text, or False if the request or the
        parsing of the reply fails (the exception is printed).
    """
    try:
        reply = requests.post('http://localhost:11434/api/generate', json={
            "model": "mistral:v0.2",
            "prompt": prompt,
        })

        pieces = []
        for line in reply.text.split('\n'):
            # Blank lines (e.g. a trailing newline) are not JSON; skipping
            # them avoids discarding the text gathered so far.
            if not line.strip():
                continue
            chunk = json.loads(line)
            if chunk["done"]:
                break
            pieces.append(chunk["response"])

        return "".join(pieces)
    except Exception as e:
        print(e)
        return False