import os
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
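
# Pick the fastest available device: CUDA, then Apple Silicon (MPS), then CPU.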
if torch.cuda.is_available():
    device = torch.device("cuda")
elif (
    hasattr(torch.backends, "mps")
    and torch.backends.mps.is_available()
    and torch.backends.mps.is_built()
):
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"running device: {device}")
auth_token = os.environ.get("TOKEN_READ_SECRET") or True
tokenizer = AutoTokenizer.from_pretrained(
    "NorHsangPha/shan_gpt2_news", token=auth_token
)
model = AutoModelForCausalLM.from_pretrained(
    "NorHsangPha/shan_gpt2_news", pad_token_id=tokenizer.eos_token_id, token=auth_token
).to(device)
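
# Greedy decoding: always pick the most probable next token.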
def greedy_search(model_inputs, max_new_tokens):
    greedy_output = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(greedy_output[0], skip_special_tokens=True)
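
# Beam search: keep the 5 highest-scoring candidate sequences and decode the best one.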
def beam_search(model_inputs, max_new_tokens):
    beam_output = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens,
        num_beams=5,
        no_repeat_ngram_size=2,  # never repeat the same 2-gram
        num_return_sequences=5,  # all beams are returned; only the first is decoded
        early_stopping=True,
    )
    return tokenizer.decode(beam_output[0], skip_special_tokens=True)
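
# Plain sampling: top_k=0 disables top-k filtering, so tokens are drawn from
# the full distribution, sharpened by a temperature of 0.6.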
def sample_outputs(model_inputs, max_new_tokens):
    sample_output = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_k=0,
        temperature=0.6,
    )
    return tokenizer.decode(sample_output[0], skip_special_tokens=True)
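
# Top-k sampling: sample only from the 50 most probable next tokens.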
def top_k_search(model_inputs, max_new_tokens):
    top_k_output = model.generate(
        **model_inputs, max_new_tokens=max_new_tokens, do_sample=True, top_k=50
    )
    return tokenizer.decode(top_k_output[0], skip_special_tokens=True)
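
# Top-p (nucleus) sampling: sample from the smallest token set whose cumulative
# probability reaches 0.92 (top_k=0 keeps top-k filtering off).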
def top_p_search(model_inputs, max_new_tokens):
    top_p_output = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=0.92,
        top_k=0,
    )
    return tokenizer.decode(top_p_output[0], skip_special_tokens=True)
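
# Tokenize the prompt, move it to the selected device, and dispatch to the
# requested decoding strategy; unknown methods fall back to temperature sampling.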
def generate_text(input_text, search_method="sample_outputs"):
    model_inputs = tokenizer(input_text, return_tensors="pt").to(device)
    max_new_tokens = 120
    match search_method:
        case "greedy_search":
            text = greedy_search(model_inputs, max_new_tokens)
        case "beem_search":
            text = beem_search(model_inputs, max_new_tokens)
        case "top_k_search":
            text = top_k_search(model_inputs, max_new_tokens)
        case "top_p_search":
            text = top_p_search(model_inputs, max_new_tokens)
        case _:
            text = sample_outputs(model_inputs, max_new_tokens)
    return text
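
# Example prompts (in Shan) paired with the decoding strategy used for each.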
GENERATE_EXAMPLES = [
    ["αααΊααα―ααΊαΆαα", "sample_outputs"],
    ["αα’ααΊααα―α΅αΊαΈααα―α΅αΊαΈααα°ααΊ", "greedy_search"],
    ["αα’ααΊααα―α΅αΊαΈααα―α΅αΊαΈααα°ααΊ", "top_k_search"],
    ["αα’ααΊααα―α΅αΊαΈααα―α΅αΊαΈααα°ααΊ", "top_p_search"],
    ["αα’ααΊααα―α΅αΊαΈααα―α΅αΊαΈααα°ααΊ", "beem_search"],
]
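
# Minimal usage sketch, assuming the script is run directly: generate a
# continuation for each example prompt with its paired decoding strategy.
if __name__ == "__main__":
    for prompt, method in GENERATE_EXAMPLES:
        print(f"--- {method} ---")
        print(generate_text(prompt, search_method=method))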