import gradio as gr
from pathlib import Path

from tsai_gpt.generate_for_app import generate_for_app

# Trained model weights and the checkpoint directory they live in.
pythia_model = "checkpoints/meta-llama/Llama-2-7b-chat-hf/lit_model.pth"
checkpoint_dir = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf")


def generate_text(prompt):
    # Generate a single sample, capped at 200 new tokens so that CPU
    # inference completes in a reasonable time.
    generated_text = generate_for_app(
        prompt,
        num_samples=1,
        max_new_tokens=200,
        temperature=0.9,
        checkpoint_dir=checkpoint_dir,
    )
    return generated_text


with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Text generation with our Pythia 160M model trained on the RedPajama sample data

        The model checkpoint is in the `checkpoints/meta-llama/Llama-2-7b-chat-hf` directory.
        The hyperparameters are exactly those emitted by the training `main.ipynb` notebook.
        The loss is below 3.5; the model produces syntactically correct but semantically
        meaningless sentences. Note that the output is limited to 200 tokens so that
        inference runs in a reasonable time (about 10 s) on CPU (the Hugging Face free
        tier); GPU inference can produce much longer sequences.

        Enter a prompt and click the "Generate text!" button to see the generated text.
        """
    )
    prompt_input = gr.Textbox(label="Enter your prompt here")
    output = gr.Textbox(label="Text generated by the Pythia 160M trained model")
    generate_button = gr.Button("Generate text!")
    generate_button.click(
        fn=generate_text,
        inputs=prompt_input,
        outputs=output,
        api_name="text_generation_sample",
    )

demo.launch()
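
# Optional variants, sketched as comments so the deployment above is unchanged
# (these use the standard Gradio Blocks API):
#
#   demo.queue().launch()     # serialize concurrent requests; helpful for slow CPU inference
#   demo.launch(share=True)   # temporary public URL when running locally instead of on Spaces
#
# A quick smoke test of the model without the UI, assuming generate_for_app
# returns a plain string as the click handler above already relies on:
#
#   print(generate_text("Once upon a time"))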