HaS-820m / hideAndSeek.py
tingxinli's picture
Update hideAndSeek.py
05c0902
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import openai
from openai import OpenAI
def hide(original_input, hide_model, tokenizer):
    """Paraphrase ``original_input`` with the hide model.

    Builds a ``"<s>Paraphrase the text:..."`` prompt, runs deterministic beam
    search (3 beams, no sampling, repetition_penalty=5.0) and decodes only the
    tokens generated after the prompt.

    Args:
        original_input: Text to paraphrase; interpolated into the prompt with
            ``%`` formatting.
        hide_model: Causal LM exposing ``.device`` and ``.generate``.
        tokenizer: Tokenizer paired with ``hide_model``.

    Returns:
        The decoded continuation with special tokens removed.
    """
    # NOTE(review): the template hard-codes "<s>"; if the tokenizer also adds a
    # BOS token automatically, the prompt may start with two — confirm.
    prompt = """<s>Paraphrase the text:%s\n\n""" % original_input
    encoded = tokenizer(prompt, return_tensors='pt').to(hide_model.device)
    prompt_len = len(encoded['input_ids'][0])

    # Output budget is 1.3x the whole prompt length (template included).
    config = GenerationConfig(
        max_new_tokens=int(prompt_len * 1.3),
        do_sample=False,
        num_beams=3,
        repetition_penalty=5.0,
    )
    output_ids = hide_model.generate(**encoded, generation_config=config)

    # The returned sequence is assumed to echo the prompt (decoder-only model),
    # so keep only what follows it in the first sequence.
    new_tokens = output_ids.cpu()[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
def seek(hide_input, hide_output, original_input, seek_model, tokenizer):
    """Generate a conversion of ``original_input`` with the seek model.

    The prompt has the shape::

        <s>Convert the text:
        {hide_input}

        {hide_output}

        Convert the text:
        {original_input}

    The model's continuation is decoded and returned. Presumably
    ``hide_input`` -> ``hide_output`` acts as a worked example so the model
    maps ``original_input`` analogously; confirm against the caller.

    Args:
        hide_input: First text of the example pair (likely the output of ``hide``).
        hide_output: Second text of the example pair.
        original_input: Text the model is asked to convert.
        seek_model: Causal LM exposing ``.device`` and ``.generate``.
        tokenizer: Tokenizer paired with ``seek_model``.

    Returns:
        The decoded continuation with special tokens removed.
    """
    fields = (hide_input, hide_output, original_input)
    prompt = """<s>Convert the text:\n%s\n\n%s\n\nConvert the text:\n%s\n\n""" % fields
    encoded = tokenizer(prompt, return_tensors='pt').to(seek_model.device)
    prompt_len = len(encoded['input_ids'][0])

    # Deterministic 3-beam search; the budget is 1.3x the full prompt length.
    # Unlike hide(), no repetition penalty is applied here.
    config = GenerationConfig(
        max_new_tokens=int(prompt_len * 1.3),
        do_sample=False,
        num_beams=3,
    )
    output_ids = seek_model.generate(**encoded, generation_config=config)

    # Drop the echoed prompt tokens from the first sequence before decoding.
    new_tokens = output_ids.cpu()[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
def get_gpt_output(prompt, api_key=None, *, model="gpt-3.5-turbo"):
    """Send ``prompt`` as a single user message to OpenAI Chat Completions.

    Args:
        prompt: Content of the single user message.
        api_key: OpenAI API key. Required: any falsy value (``None``, ``""``)
            is rejected before a client is created.
        model: Chat model name. Keyword-only; defaults to "gpt-3.5-turbo",
            the model this function always used before the parameter existed.

    Returns:
        ``completion.choices[0].message.content`` of the response.
        NOTE(review): the SDK types this field as optional, so it may be
        ``None`` (e.g. refusals or tool calls) — confirm callers handle that.

    Raises:
        ValueError: If no API key is supplied.
    """
    if not api_key:
        # Callers must pass the key explicitly; this raises before the client
        # could fall back to any environment-based configuration.
        raise ValueError('an OpenAI API key is needed for this function')
    client = OpenAI(api_key=api_key)
    completion = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
    )
    return completion.choices[0].message.content