import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from openai import OpenAI

def hide(original_input, hide_model, tokenizer):
    """Paraphrase the original input with the local hide model, producing a
    sanitized version that can be sent to the remote LLM."""
    hide_template = """<s>Paraphrase the text:%s\n\n"""
    input_text = hide_template % original_input
    inputs = tokenizer(input_text, return_tensors='pt').to(hide_model.device)
    pred = hide_model.generate(
        **inputs,
        generation_config=GenerationConfig(
            # Allow the paraphrase to run somewhat longer than the input.
            max_new_tokens=int(len(inputs['input_ids'][0]) * 1.3),
            do_sample=False,
            num_beams=3,
            repetition_penalty=5.0,
        ),
    )
    # Strip the prompt tokens; keep only the newly generated continuation.
    pred = pred.cpu()[0][len(inputs['input_ids'][0]):]
    hide_input = tokenizer.decode(pred, skip_special_tokens=True)
    return hide_input

def seek(hide_input, hide_output, original_input, seek_model, tokenizer):
    """Recover the output for the original input: the seek model is given the
    (hidden input, hidden output) pair as an in-context example and asked to
    convert the original input the same way."""
    seek_template = """<s>Convert the text:\n%s\n\n%s\n\nConvert the text:\n%s\n\n"""
    input_text = seek_template % (hide_input, hide_output, original_input)
    inputs = tokenizer(input_text, return_tensors='pt').to(seek_model.device)
    pred = seek_model.generate(
        **inputs,
        generation_config=GenerationConfig(
            max_new_tokens=int(len(inputs['input_ids'][0]) * 1.3),
            do_sample=False,
            num_beams=3,
        ),
    )
    # Strip the prompt tokens; keep only the newly generated continuation.
    pred = pred.cpu()[0][len(inputs['input_ids'][0]):]
    original_output = tokenizer.decode(pred, skip_special_tokens=True)
    return original_output

def get_gpt_output(prompt, api_key=None):
    """Send the (sanitized) prompt to the remote LLM and return its reply."""
    if not api_key:
        raise ValueError('an OpenAI API key is required for this function')
    client = OpenAI(api_key=api_key)
    completion = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "user", "content": prompt}
        ]
    )
    return completion.choices[0].message.content
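
# --- Minimal usage sketch (an assumption-laden example, not part of the
# original file). The checkpoint path, sample text, prompt wording, and
# API-key handling below are illustrative; substitute the real hide/seek
# model(s) you use. Shown so the hide -> remote LLM -> seek round trip is
# visible end to end.
if __name__ == "__main__":
    import os

    # Hypothetical checkpoint path; hide and seek may also be separate models.
    model_name = "path/to/hide-and-seek-model"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model = model.to("cuda" if torch.cuda.is_available() else "cpu")

    original_input = "Alice from Acme Corp asked about the Q3 merger plans."

    # 1. Sanitize the sensitive input locally.
    hidden_input = hide(original_input, model, tokenizer)
    # 2. Send only the sanitized text to the remote LLM.
    hidden_output = get_gpt_output(
        "Summarize the text:\n%s" % hidden_input,
        api_key=os.environ.get("OPENAI_API_KEY"),
    )
    # 3. Recover the output for the original input locally.
    recovered = seek(hidden_input, hidden_output, original_input, model, tokenizer)
    print(recovered)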