File size: 2,690 Bytes
c0a205f
 
 
 
4448a6c
c0a205f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4448a6c
c0a205f
 
 
4448a6c
c0a205f
 
 
 
 
 
 
 
 
 
0dd0b82
c0a205f
 
4448a6c
0dd0b82
 
4448a6c
0dd0b82
 
 
4448a6c
c0a205f
 
0dd0b82
c0a205f
 
 
 
 
 
 
 
cf96a2d
ab558ad
4646178
 
c0a205f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import gradio as gr
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIteratorStreamer

model_id = "rasyosef/gpt2-medium-amharic-28k-512"

# Load the tokenizer and the causal-LM weights for the checkpoint named above.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Text-generation pipeline invoked by generate() below. The tokenizer's pad and
# eos token ids are handed to the pipeline explicitly rather than left implicit.
gpt2_am = pipeline(
  task="text-generation",
  tokenizer=tokenizer,
  model=model,
  eos_token_id=tokenizer.eos_token_id,
  pad_token_id=tokenizer.pad_token_id,
)

def generate(prompt):
  """Stream a sampled continuation of ``prompt`` from the ``gpt2_am`` pipeline.

  The prompt plus the generated continuation is capped at 128 tokens, counted
  with ``tokenizer.tokenize``. A prompt of 128 tokens or more is not sent to
  the model. Instead, the prompt is yielded back with an explanatory message
  appended.

  Args:
    prompt: The text entered by the user.

  Yields:
    The accumulated text so far, stripped of surrounding whitespace, after each
    streamed chunk. The prompt itself is included because the streamer is
    created with ``skip_prompt=False``.

  Raises:
    Exception: Whatever the pipeline raised in the worker thread is re-raised
      here. Otherwise the failure would be lost and the stream would wait for
      its 300-second timeout.
  """
  max_total_tokens = 128
  prompt_length = len(tokenizer.tokenize(prompt))
  if prompt_length >= max_total_tokens:
    yield prompt + "\n\nPrompt is too long. It needs to be less than 128 tokens."
    return

  # The guard above guarantees this is at least 1, so no clamping is needed.
  max_new_tokens = max_total_tokens - prompt_length
  streamer = TextIteratorStreamer(tokenizer=tokenizer, skip_prompt=False, skip_special_tokens=True, timeout=300.0)
  generation_kwargs = {
    "text_inputs": prompt,
    "max_new_tokens": max_new_tokens,
    "temperature": 0.4,
    "do_sample": True,
    "top_k": 8,
    "top_p": 0.8,
    "repetition_penalty": 1.4,
    "streamer": streamer,
  }
  errors = []

  def _run_pipeline():
    # Runs in the worker thread. Without this wrapper, an exception raised by
    # the pipeline would never reach the consumer loop below: the streamer
    # would stop receiving text and the loop would block until its timeout.
    try:
      gpt2_am(**generation_kwargs)
    except Exception as exc:
      errors.append(exc)
      # Flush any pending text and push the stop signal so the loop ends now.
      streamer.end()

  thread = Thread(target=_run_pipeline)
  thread.start()

  generated_text = ""
  for chunk in streamer:
    generated_text += chunk
    yield generated_text.strip()

  # The streamer has signalled completion, so the worker is finishing; reap it
  # and surface any failure it recorded.
  thread.join()
  if errors:
    raise errors[0]

# The CSS rule colours the prompt textarea blue; it targets the Textbox below
# through its elem_id="prompt_textbox".
with gr.Blocks(css="#prompt_textbox textarea {color: blue}") as demo:
  gr.Markdown("""
  # GPT2 Amharic
  This is a demo for a smaller version of OpenAI's [gpt2](https://huggingface.co/openai-community/gpt2) decoder transformer model pretrained for 2 days on `290 million` tokens of **Amharic** text. The context size of [gpt2-medium-amharic](https://huggingface.co/rasyosef/gpt2-medium-amharic-28k-512) is 512 tokens. This is a base model and hasn't undergone any supervised fine-tuning yet.
  Please **enter a prompt** and click the **Generate** button to generate completions for the prompt.
  #### Text generation parameters:
  - `temperature` : **0.4**
  - `do_sample` : **True**
  - `top_k` : **8**
  - `top_p` : **0.8**
  - `repetition_penalty` : **1.4**
  """)

  prompt = gr.Textbox(label="Prompt", placeholder="Enter prompt here", lines=4, interactive=True, elem_id="prompt_textbox")
  with gr.Row():
    with gr.Column():
      gen = gr.Button("Generate")
    with gr.Column():
      btn = gr.ClearButton([prompt])
  # The prompt box doubles as the output: generate() streams the prompt plus
  # its continuation back into the same Textbox.
  gen.click(generate, inputs=[prompt], outputs=[prompt])
  examples = gr.Examples(
    examples=[
      "አዲስ አበባ",
      "በእንግሊዙ ፕሬሚየር ሊግ",
      "ፕሬዚዳንት ዶናልድ ትራምፕ",
      "በመስቀል አደባባይ"
    ],
    inputs=[prompt],
  )
# Enable the request queue and start the server. The queue is presumably here
# because generate() is a streaming generator; confirm for the installed Gradio version.
demo.queue().launch(debug=True)