Spaces:
Running
Running
File size: 2,057 Bytes
fa9b36e b420929 5135364 b420929 fa9b36e b420929 fa9b36e b420929 bd2a222 b420929 5eb3a8a b420929 d9a39ef fa9b36e 0116f84 79f1518 457981c 79f1518 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import gradio as gr
import logging
import os
import json
from transformers import AutoTokenizer, AutoModelForCausalLM
# Set the Hugging Face Hub flag that (per its name) disables the symlink
# warning; NOTE(review): this runs after `transformers` is imported — confirm
# the flag is still read late enough to take effect.
os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
# Downloaded model/tokenizer files are cached in a `cache` folder under the
# current working directory of the process.
CACHE_DIR = os.getcwd() + '/cache'
# Hub repository id of the causal language model served by this app.
MODEL_NAME = 'grandestroyer/joefreaks'
# The model and tokenizer are loaded once at import time and then read as
# module-level globals by the generation functions below.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, resume_download=None, cache_dir=CACHE_DIR)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, resume_download=None, cache_dir=CACHE_DIR)
# Reuse the end-of-sequence token as the padding token
# (presumably the tokenizer ships without a dedicated pad token — confirm).
tokenizer.pad_token = tokenizer.eos_token
def generate_from_full_params(prompt='', n=1, temp=0.7, top_p=0.9, top_k=40, max_length=500, exclude_repetitions=False):
    """Sample ``n`` text sequences from the module-level ``model``.

    Any of ``n``, ``temp``, ``top_p``, ``top_k`` or ``max_length`` passed as
    ``None`` falls back to the default shown in the signature. This matters
    for the Gradio UI, where a number field left blank is expected to arrive
    as ``None`` rather than as the Python default.

    Args:
        prompt: Text to continue. When empty (or ``None``) no ``input_ids``
            are passed to ``model.generate``.
        n: Number of sequences to return (``num_return_sequences``).
        temp: Sampling temperature. Values ``<= 0`` are replaced by ``0.1``
            and values ``>= 2`` by ``1.9``; anything in between is used as is.
        top_p: Forwarded to ``model.generate`` as ``top_p``.
        top_k: Forwarded to ``model.generate`` as ``top_k``.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
        exclude_repetitions: When truthy, ``no_repeat_ngram_size`` is set to
            ``2``; otherwise ``0``.

    Returns:
        A list with the decoded sequences, special tokens skipped and
        surrounding whitespace stripped.
    """
    # Blank UI fields come through as None; comparing None with a number
    # (the temperature clamp below) would raise TypeError, so substitute the
    # signature defaults first.
    n = 1 if n is None else n
    temp = 0.7 if temp is None else temp
    top_p = 0.9 if top_p is None else top_p
    top_k = 40 if top_k is None else top_k
    max_length = 500 if max_length is None else max_length

    no_repeat_ngram_size = 2 if exclude_repetitions else 0
    # Keep the temperature strictly inside (0, 2): out-of-range values are
    # replaced, in-range values are left untouched.
    temp_normalized = 0.1 if temp <= 0 else (1.9 if temp >= 2 else temp)
    print('Generating with params prompt="{}", n={}, temp={}, top_p={}, top_k={}, max_length={}, no_repeat_ngram_size={}'
          .format(prompt, n, temp_normalized, top_p, top_k, max_length, no_repeat_ngram_size))

    # An empty or missing prompt means unconditional generation: pass no
    # input ids instead of trying to encode nothing.
    input_ids = tokenizer.encode(prompt, return_tensors='pt') if prompt else None
    generated = model.generate(
        input_ids=input_ids,
        top_p=top_p,
        top_k=top_k,
        no_repeat_ngram_size=no_repeat_ngram_size,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=temp_normalized,
        num_return_sequences=n,
        max_length=max_length,
    )

    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    decoded = [text.strip() for text in decoded]
    print(decoded)
    return decoded
def generate_one(temp, prompt, exclude_repetitions):
    """Generate a single sequence and return it as a string.

    Args:
        temp: Sampling temperature. Values ``<= 0`` are replaced by ``0.1``
            and values ``>= 2`` by ``1.9`` before generation.
        prompt: Text to continue; may be empty.
        exclude_repetitions: Forwarded to ``generate_from_full_params``.

    Returns:
        The first (and only) decoded sequence produced for ``n=1``.
    """
    temp_normalized = 0.1 if temp <= 0 else (1.9 if temp >= 2 else temp)
    # Delegate to the full-parameter generator defined in this module; no
    # function named `generate` exists here, so calling one raised NameError.
    results = generate_from_full_params(
        n=1,
        temp=temp_normalized,
        prompt=prompt,
        exclude_repetitions=exclude_repetitions,
    )
    return results[0]
# Web UI / API wrapper around generate_from_full_params.
# The input components are listed in the same order as the function's
# parameters: prompt, n, temp, top_p, top_k, max_length, exclude_repetitions.
# The gr.Number(precision=0) entries sit at the positions of n, top_k and
# max_length; the plain 'number' entries at the positions of temp and top_p.
# The returned list of strings is rendered through the 'json' output.
gr_interface = gr.Interface(
fn = generate_from_full_params,
inputs = ['text', gr.Number(precision=0), 'number', 'number', gr.Number(precision=0), gr.Number(precision=0), 'checkbox'],
outputs = 'json'
)
# Start serving as soon as this module is executed.
gr_interface.launch()
|