kaktuspassion commited on
Commit
81c89b2
1 Parent(s): cb9de27

initial commit

Browse files
Files changed (2) hide show
  1. app.py +44 -0
  2. requirements.txt +76 -0
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+
4
+ # Define the model loading function
5
+ def load_model(model_name):
6
+ # Load the text generation pipeline
7
+ generator = pipeline('text-generation', model=model_name)
8
+ return generator
9
+
10
+ # Define the text generation function
11
+ def generate_text(model_name, prompt, custom_prompt, temperature, max_length, top_p, beam_size, frequency_penalty, presence_penalty):
12
+ if temperature == 0:
13
+ temperature = 0.0001
14
+ do_sample = False
15
+ else:
16
+ do_sample = True
17
+ generator = load_model(model_name)
18
+ if custom_prompt:
19
+ prompt = custom_prompt
20
+ generate_text = generator(prompt, temperature=float(temperature), max_length=max_length, top_p=top_p, num_beams=beam_size, truncation=True)
21
+ return generate_text[0]['generated_text']
22
+
23
+ # Pre-written prompts
24
+ prompts = ["Write a tagline for an ice cream shop", "Describe the Word War II", "Write a short story about a robot", "Explain the concept of gravity"]
25
+
26
+ # Interface
27
+ demo = gr.Interface(
28
+ fn=generate_text,
29
+ inputs=[
30
+ gr.Radio(choices=["gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"], label="Model", value="gpt2", info="Choose the size of the model to use."),
31
+ gr.Dropdown(choices=prompts, label="Prompt", info="Select a pre-written prompt."),
32
+ gr.Textbox(label="Custom Prompt", placeholder="Or write your own prompt here", lines=5),
33
+ gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=1.0, label="Temperature", info="Controls randomness: Higher values make the output more random, while lower values make the output more deterministic and repetitive."),
34
+ gr.Slider(minimum=1, maximum=256, value=16, label="Maximum Length", info="The maximum number of tokens to generate shared between the prompt and the completion."),
35
+ gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.0, label="Top P", info="Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered."),
36
+ gr.Slider(minimum=1, maximum=10, value=1, step=1, label="Beam Size", info="Number of beams to use for beam search. 1 means Greedy decoding."),
37
+
38
+ ],
39
+ outputs=["text"],
40
+ title="GPT-2 playground Mockup",
41
+ description="Adjust the sliders and enter a prompt to generate text using GPT-2."
42
+ )
43
+
44
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ altair==5.3.0
3
+ annotated-types==0.6.0
4
+ anyio==4.3.0
5
+ attrs==23.2.0
6
+ certifi==2024.2.2
7
+ charset-normalizer==3.3.2
8
+ click==8.1.7
9
+ contourpy==1.2.1
10
+ cycler==0.12.1
11
+ exceptiongroup==1.2.1
12
+ fastapi==0.110.2
13
+ ffmpy==0.3.2
14
+ filelock==3.13.4
15
+ fonttools==4.51.0
16
+ fsspec==2024.3.1
17
+ gradio==4.27.0
18
+ gradio_client==0.15.1
19
+ h11==0.14.0
20
+ httpcore==1.0.5
21
+ httpx==0.27.0
22
+ huggingface-hub==0.22.2
23
+ idna==3.7
24
+ importlib_resources==6.4.0
25
+ Jinja2==3.1.3
26
+ jsonschema==4.21.1
27
+ jsonschema-specifications==2023.12.1
28
+ kiwisolver==1.4.5
29
+ markdown-it-py==3.0.0
30
+ MarkupSafe==2.1.5
31
+ matplotlib==3.8.4
32
+ mdurl==0.1.2
33
+ mpmath==1.3.0
34
+ networkx==3.2.1
35
+ numpy==1.26.4
36
+ orjson==3.10.1
37
+ packaging==24.0
38
+ pandas==2.2.2
39
+ pillow==10.3.0
40
+ pydantic==2.7.1
41
+ pydantic_core==2.18.2
42
+ pydub==0.25.1
43
+ Pygments==2.17.2
44
+ pyparsing==3.1.2
45
+ python-dateutil==2.9.0.post0
46
+ python-multipart==0.0.9
47
+ pytz==2024.1
48
+ PyYAML==6.0.1
49
+ referencing==0.35.0
50
+ regex==2024.4.16
51
+ requests==2.31.0
52
+ rich==13.7.1
53
+ rpds-py==0.18.0
54
+ ruff==0.4.1
55
+ safetensors==0.4.3
56
+ semantic-version==2.10.0
57
+ shellingham==1.5.4
58
+ six==1.16.0
59
+ sniffio==1.3.1
60
+ starlette==0.37.2
61
+ sympy==1.12
62
+ tokenizers==0.19.1
63
+ tomlkit==0.12.0
64
+ toolz==0.12.1
65
+ torch==2.2.2
66
+ torchaudio==2.2.2
67
+ torchvision==0.17.2
68
+ tqdm==4.66.2
69
+ transformers==4.40.0
70
+ typer==0.12.3
71
+ typing_extensions==4.11.0
72
+ tzdata==2024.1
73
+ urllib3==2.2.1
74
+ uvicorn==0.29.0
75
+ websockets==11.0.3
76
+ zipp==3.18.1