from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from openai import OpenAI

def hide(original_input, hide_model, tokenizer):
    """Paraphrase the original prompt locally so that sensitive content is hidden."""
    hide_template = """<s>Paraphrase the text:%s\n\n"""
    input_text = hide_template % original_input
    inputs = tokenizer(input_text, return_tensors='pt').to(hide_model.device)
    pred = hide_model.generate(
        **inputs,
        generation_config=GenerationConfig(
            # Allow the paraphrase to run somewhat longer than the input.
            max_new_tokens=int(len(inputs['input_ids'][0]) * 1.3),
            do_sample=False,
            num_beams=3,
            repetition_penalty=5.0,
        ),
    )
    # Keep only the newly generated tokens, dropping the prompt prefix.
    pred = pred.cpu()[0][len(inputs['input_ids'][0]):]
    hide_input = tokenizer.decode(pred, skip_special_tokens=True)
    return hide_input

def seek(hide_input, hide_output, original_input, seek_model, tokenizer):
    """Recover the answer to the original prompt from the response to the hidden one."""
    seek_template = """<s>Convert the text:\n%s\n\n%s\n\nConvert the text:\n%s\n\n"""
    input_text = seek_template % (hide_input, hide_output, original_input)
    inputs = tokenizer(input_text, return_tensors='pt').to(seek_model.device)
    pred = seek_model.generate(
        **inputs,
        generation_config=GenerationConfig(
            max_new_tokens=int(len(inputs['input_ids'][0]) * 1.3),
            do_sample=False,
            num_beams=3,
        ),
    )
    # Keep only the newly generated tokens, dropping the prompt prefix.
    pred = pred.cpu()[0][len(inputs['input_ids'][0]):]
    original_output = tokenizer.decode(pred, skip_special_tokens=True)
    return original_output

def get_gpt_output(prompt, api_key=None):
    """Query the remote LLM (here GPT-3.5) with the hidden prompt."""
    if not api_key:
        raise ValueError('an OpenAI API key is required for this function')
    client = OpenAI(api_key=api_key)
    completion = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "user", "content": prompt}
        ],
    )
    return completion.choices[0].message.content
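
# --- Hypothetical end-to-end usage sketch (not part of the original file) ---
# A minimal demo of the hide -> remote LLM -> seek pipeline, assuming a local
# paraphrase ("hide") model and a recovery ("seek") model already exist. The
# checkpoint paths, the example prompt, and the API key are illustrative
# placeholders only.
if __name__ == '__main__':
    hide_model_path = 'path/to/hide_model'  # assumed local checkpoint
    seek_model_path = 'path/to/seek_model'  # assumed local checkpoint

    tokenizer = AutoTokenizer.from_pretrained(hide_model_path)
    hide_model = AutoModelForCausalLM.from_pretrained(hide_model_path).eval()
    seek_model = AutoModelForCausalLM.from_pretrained(seek_model_path).eval()

    original_input = 'Summarize this note: Alice meets Bob in Paris on Friday.'

    # 1. Locally paraphrase (anonymize) the sensitive prompt.
    hide_input = hide(original_input, hide_model, tokenizer)

    # 2. Send only the paraphrased prompt to the remote LLM.
    hide_output = get_gpt_output(hide_input, api_key='YOUR_API_KEY')

    # 3. Locally map the remote answer back onto the original prompt.
    original_output = seek(hide_input, hide_output, original_input, seek_model, tokenizer)
    print(original_output)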