File size: 5,951 Bytes
363236f
 
b6c25db
 
 
363236f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8ddd98
363236f
b6c25db
 
 
363236f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8ddd98
 
363236f
c8ddd98
363236f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6c25db
363236f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8ddd98
363236f
 
c8ddd98
363236f
 
 
 
 
 
 
 
 
 
c8ddd98
 
363236f
 
 
 
c8ddd98
363236f
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# -*- coding: utf-8 -*-

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import argparse
import re

import streamlit as st
import random
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import tokenizers

# Seed Python's RNG with None (i.e. not a fixed value) so the fallback prompt
# picked with random.choice() is not tied to a hard-coded seed.
random.seed(None)

# Opening phrases; one is picked at random as the prompt when the user submits
# an empty text box.
# NOTE(review): these literals look mis-encoded (possibly Hebrew UTF-8 decoded
# with the wrong codec) - confirm the file's encoding.
suggested_text_list = [
    'ื”ืฉื“ ื”ื•ืคื™ืข ืžื•ืœ',
    'ืงืืœื™ ืฉืœืคื” ืืช',
    'ืคืขื ืื—ืช ืœืคื ื™ ืฉื ื™ื ืจื‘ื•ืช',
    'ื”ืืจื™ ืคื•ื˜ืจ ื—ื™ื™ืš ื—ื™ื•ืš ื ื‘ื•ืš',
    'ื•ืื– ื”ืคืจืชื™ ืืช ื›ืœ ื›ืœืœื™ ื”ื˜ืงืก ื›ืฉ',
]

@st.cache(hash_funcs={tokenizers.Tokenizer: id, tokenizers.AddedToken: id})
def load_model(model_name):
    """Load the tokenizer and causal language model identified by ``model_name``.

    The call is wrapped in ``st.cache``; ``hash_funcs`` makes the cache hash
    ``tokenizers.Tokenizer`` and ``tokenizers.AddedToken`` objects with ``id``.

    Args:
        model_name: Value handed unchanged to both ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple - note the model comes first.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    loaded_model = AutoModelForCausalLM.from_pretrained(model_name)
    return (loaded_model, loaded_tokenizer)

def extend(input_text, max_size=20, top_k=50, top_p=0.95, temperature=0.7):
    """Generate a continuation of ``input_text`` and return the cleaned-up text.

    Reads the module-level names ``tokenizer``, ``model``, ``device``,
    ``stop_token`` and ``new_lines``, which the ``__main__`` section assigns
    before calling this function.

    Args:
        input_text: Prompt text; may be empty. "<|startoftext|>" is always
            prepended before encoding.
        max_size: Added to the prompt's token count to form the ``max_length``
            passed to ``model.generate``.
        top_k: Forwarded to ``model.generate`` when sampling.
        top_p: Forwarded to ``model.generate`` when sampling.
        temperature: Forwarded to ``model.generate`` when sampling. A value
            <= 0 switches to greedy decoding (``do_sample=False``) instead of
            sampling, because sampling requires a strictly positive value.

    Returns:
        The prompt followed by the generated text, with "<|startoftext|>" and
        carriage returns removed and each double newline replaced by a single
        one (one pass). If that leaves an empty string, a fixed fallback
        message is returned instead.
    """
    # Both the empty and non-empty prompt get the same start marker prefix.
    input_text = "<|startoftext|>" + input_text

    encoded_prompt = tokenizer.encode(
        input_text, add_special_tokens=False, return_tensors="pt")
    encoded_prompt = encoded_prompt.to(device)

    # An empty encoding is passed to generate() as "no input ids".
    input_ids = None if encoded_prompt.size()[-1] == 0 else encoded_prompt

    generation_kwargs = dict(
        input_ids=input_ids,
        max_length=max_size + len(encoded_prompt[0]),
        repetition_penalty=2.0,
        num_return_sequences=1,
    )
    if temperature > 0:
        generation_kwargs.update(
            do_sample=True,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
        )
    else:
        # Sampling with a non-positive temperature is rejected by generate();
        # greedy decoding is the deterministic equivalent.
        generation_kwargs["do_sample"] = False

    output_sequences = model.generate(**generation_kwargs)

    # Remove the batch dimension when returning multiple sequences
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()

    # Decoded form of the prompt; its length is how much of each decoded
    # output is replaced by the original prompt text. Loop-invariant.
    prompt_text = tokenizer.decode(
        encoded_prompt[0], clean_up_tokenization_spaces=True)

    # Stays empty if generate() returned no sequences, which then yields the
    # fallback message below instead of a NameError.
    total_sequence = ""

    for generated_sequence in output_sequences:
        # Decode text
        text = tokenizer.decode(
            generated_sequence.tolist(), clean_up_tokenization_spaces=True)

        # Keep only the text before the stop token. partition() returns the
        # whole string when the marker is absent, whereas slicing with
        # find()'s -1 would silently drop the last character.
        if stop_token:
            text = text.partition(stop_token)[0]

        # Likewise keep only the text before the padding marker.
        if new_lines:
            text = text.partition(new_lines)[0]

        # Add the prompt at the beginning of the sequence. Remove the excess
        # text that was used for pre-processing. With several sequences the
        # last one wins.
        total_sequence = input_text + text[len(prompt_text):]

    parsed_text = total_sequence.replace("<|startoftext|>", "").replace("\r", "").replace("\n\n", "\n")
    if not parsed_text:
        # Fallback message shown when nothing is left to display.
        # NOTE(review): this literal looks mis-encoded - confirm file encoding.
        parsed_text = "ืฉื’ื™ืื”"
    return parsed_text



if __name__ == "__main__":
    # Streamlit script body: builds the page, loads the model and runs one
    # generation when the "Generate Text" button is pressed.

    # Local import: only this script section needs it, for escaping the
    # generated text before embedding it in raw HTML below.
    import html

    st.title("Hebrew text generator: Science Fiction and Fantasy (GPT-Neo)")
    model, tokenizer = load_model("./model")

    # Markers read by extend(); decoded text is cut at the first occurrence.
    stop_token = "<|endoftext|>"
    new_lines = "<|pad|>"

    # Draw a fresh seed on every script run. Convert to a plain int so torch
    # receives a scalar (not a 1-element array) and the value displays as a
    # number rather than "[n]".
    np.random.seed(None)
    random_seed = int(np.random.randint(10000))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count() if torch.cuda.is_available() else 0

    torch.manual_seed(random_seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(random_seed)

    model.to(device)

    # NOTE(review): the default prompt literal looks mis-encoded - confirm the
    # file's encoding.
    text_area = st.text_area("Enter the first few words (or leave blank), tap on \"Generate Text\" below. Tapping again will produce a different result.", 'ื”ืื“ื ื”ืื—ืจื•ืŸ ืขืœื™ ืื“ืžื•ืช ื™ืฉื‘ ืœื‘ื“ ื‘ื—ื“ืจื• ื›ืฉืœืคืชืข ื ืฉืžืขื” ื“ืคื™ืงื”')

    st.sidebar.subheader("Configurable parameters")

    max_len = st.sidebar.slider("Max-Length", 0, 256, 160, help="The maximum length of the sequence to be generated.")
    top_k = st.sidebar.slider("Top-K", 0, 100, 40, help="The number of highest probability vocabulary tokens to keep for top-k-filtering.")
    top_p = st.sidebar.slider("Top-P", 0.0, 1.0, 0.92, help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.")
    temperature = st.sidebar.slider("Temperature", 0.0, 1.0, 1.0, help="The value used to module the randomness of the output tokens.")

    if st.button("Generate Text"):
        with st.spinner(text="Generating results..."):
            st.subheader("Result")
            print(f"device:{device}, n_gpu:{n_gpu}, random_seed:{random_seed}, maxlen:{max_len}, top_k:{top_k}, top_p:{top_p}")

            # A blank box falls back to one of the suggested openings.
            if not text_area.strip():
                text_area = random.choice(suggested_text_list)

            result = extend(
                input_text=text_area,
                max_size=int(max_len),
                top_k=int(top_k),
                top_p=float(top_p),
                temperature=float(temperature),
            )

            # NOTE: len(result) counts characters, although the label says bytes.
            print("Done length: " + str(len(result)) + " bytes")

            # The result contains the user's own text plus model output and is
            # rendered with unsafe_allow_html=True, so escape it first to keep
            # any markup in it from being interpreted by the browser.
            safe_result = html.escape(result)
            st.markdown(f"<p dir=\"rtl\" style=\"text-align:right;\"> {safe_result} </p>", unsafe_allow_html=True)
            st.write(f"\n\nResult length: {len(result)} bytes\n Random seed: {random_seed}\ntop_k: {top_k}\ntop_p: {top_p}\ntemperature: {temperature}\nmax_len: {max_len}\ndevice: {device}\nn_gpu: {n_gpu}")
            print(f"\"{result}\"")

    st.markdown(
        """Hebrew text generation model based on EleutherAI's gpt-neo architecture. Originally trained on a TPUv3-8 which was made avilable to me via the [TPU Research Cloud Program](https://sites.research.google/trc/). The model was then slightly fine-tuned upon science fiction and fantasy text."""
    )

    st.markdown("<footer><hr><p style=\"font-size:14px\">The site is fan made and is not affiliated with any author in any way.</p><p style=\"font-size:12px\">By <a href=\"https://linktr.ee/Norod78\">Doron Adler</a></p></footer> ", unsafe_allow_html=True)