# HaS-820m / hideAndSeek.py
# Source: Hugging Face repo (tingxinli), file "hideAndSeek.py", revision 05c0902.
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import openai
from openai import OpenAI
def hide(original_input, hide_model, tokenizer):
    """Paraphrase `original_input` with the hide model to produce a masked version.

    Args:
        original_input: Raw text to be rewritten.
        hide_model: Causal LM with a Hugging Face `generate` API; its `.device`
            attribute determines where inputs are placed.
        tokenizer: Tokenizer matching `hide_model`.

    Returns:
        The paraphrased text, with the echoed prompt and special tokens stripped.
    """
    hide_template = """<s>Paraphrase the text:%s\n\n"""
    input_text = hide_template % original_input
    inputs = tokenizer(input_text, return_tensors='pt').to(hide_model.device)
    prompt_len = len(inputs['input_ids'][0])
    # Pure inference: disable autograd so no graph is built and memory stays flat.
    with torch.no_grad():
        pred = hide_model.generate(
            **inputs,
            generation_config=GenerationConfig(
                # ~30% headroom over the prompt length for the paraphrase.
                max_new_tokens=int(prompt_len * 1.3),
                do_sample=False,  # deterministic beam search
                num_beams=3,
                repetition_penalty=5.0,  # discourage copying the input verbatim
            ),
        )
    # `generate` returns prompt + continuation; keep only the new tokens.
    pred = pred.cpu()[0][prompt_len:]
    hide_input = tokenizer.decode(pred, skip_special_tokens=True)
    return hide_input
def seek(hide_input, hide_output, original_input, seek_model, tokenizer):
    """Map a model answer about the hidden text back onto the original text.

    Builds a one-shot prompt: (hide_input -> hide_output) as the demonstration,
    then asks the seek model to "convert" `original_input` the same way.

    Args:
        hide_input: The paraphrased (masked) text previously produced by `hide`.
        hide_output: The response obtained for `hide_input`.
        original_input: The user's original, unmasked text.
        seek_model: Causal LM with a Hugging Face `generate` API; its `.device`
            attribute determines where inputs are placed.
        tokenizer: Tokenizer matching `seek_model`.

    Returns:
        The de-anonymized output corresponding to `original_input`.
    """
    seek_template = """<s>Convert the text:\n%s\n\n%s\n\nConvert the text:\n%s\n\n"""
    input_text = seek_template % (hide_input, hide_output, original_input)
    inputs = tokenizer(input_text, return_tensors='pt').to(seek_model.device)
    prompt_len = len(inputs['input_ids'][0])
    # Pure inference: disable autograd so no graph is built and memory stays flat.
    with torch.no_grad():
        pred = seek_model.generate(
            **inputs,
            generation_config=GenerationConfig(
                # ~30% headroom over the prompt length for the converted text.
                max_new_tokens=int(prompt_len * 1.3),
                do_sample=False,  # deterministic beam search
                num_beams=3,
            ),
        )
    # `generate` returns prompt + continuation; keep only the new tokens.
    pred = pred.cpu()[0][prompt_len:]
    original_output = tokenizer.decode(pred, skip_special_tokens=True)
    return original_output
def get_gpt_output(prompt, api_key=None, model="gpt-3.5-turbo"):
    """Send `prompt` to an OpenAI chat model and return the reply text.

    Args:
        prompt: User message to send as a single-turn chat.
        api_key: OpenAI API key; required (no environment fallback here).
        model: Chat model name (default: "gpt-3.5-turbo").

    Returns:
        The assistant's reply content as a string.

    Raises:
        ValueError: If `api_key` is missing or empty.
    """
    if not api_key:
        raise ValueError('an OpenAI API key is needed for this function')
    client = OpenAI(api_key=api_key)
    completion = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "user", "content": prompt}
        ]
    )
    return completion.choices[0].message.content