import time

from huggingface_hub import InferenceClient
import gradio as gr

# Initialize the inference client with the new LLM
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

# Define the system prompt for enhancing user prompts
SYSTEM_PROMPT = (
    "You are a prompt enhancer and your work is to enhance the given prompt under 100 words "
    "without changing the essence, only write the enhanced prompt and nothing else."
)


def format_prompt(message):
    """Build a Mixtral-Instruct prompt asking the model to enhance *message*.

    Mixtral-Instruct has no dedicated system role, so the system prompt is
    folded into the single user turn. Two consecutive ``[INST]`` blocks
    (system, then user) with no assistant turn in between do not form a
    valid conversation for this model's chat template.

    The user's text is passed through unchanged: nothing (such as a
    timestamp) is appended, because anything inside ``[INST]`` is part of
    what the model is asked to enhance.
    """
    return f"[INST] {SYSTEM_PROMPT}\n\n{message} [/INST]"


def generate(message, max_new_tokens=256, temperature=0.9, top_p=0.95, repetition_penalty=1.0):
    """Stream an enhanced version of *message* from the LLM.

    This is a generator: after each generated token it yields the text
    accumulated so far (whitespace-trimmed), which Gradio renders
    incrementally in the output textbox.

    Args:
        message: The user's prompt to enhance.
        max_new_tokens: Upper bound on generated tokens (cast to ``int``).
        temperature: Sampling temperature; values below 0.01 are raised to
            0.01 because the backend requires a strictly positive value.
        top_p: Nucleus-sampling threshold. A value of 1.0 or more means
            "no nucleus filtering" and the parameter is omitted.
        repetition_penalty: Penalty applied to repeated tokens.

    Yields:
        The enhanced prompt text generated so far.

    Raises:
        gr.Error: If *message* is empty/blank, or if the inference request
            fails (so Gradio shows a readable message in the UI).
    """
    if not (message or "").strip():
        raise gr.Error("Please enter a prompt to enhance.")

    # The backend rejects temperature <= 0, so clamp to a small positive value.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    generate_kwargs = {
        "temperature": temperature,
        "max_new_tokens": int(max_new_tokens),
        # Text Generation Inference only accepts top_p strictly inside (0, 1);
        # 1.0 (reachable via the slider) means "keep every token", which is
        # exactly what omitting the parameter does.
        "top_p": top_p if top_p < 1.0 else None,
        "repetition_penalty": float(repetition_penalty),
        "do_sample": True,
        # A fresh seed makes every request payload unique (the role the old
        # timestamp suffix played) without injecting text into the prompt.
        "seed": time.time_ns() % 2**32,
    }

    try:
        stream = client.text_generation(
            format_prompt(message),
            **generate_kwargs,
            stream=True,
            details=True,
            return_full_text=False,
        )
        output = ""
        for response in stream:
            # Skip special tokens (e.g. the end-of-sequence marker) so they
            # never leak into the enhanced prompt shown to the user.
            if response.token.special:
                continue
            output += response.token.text
            yield output.strip()
    except Exception as err:  # UI boundary: report the failure in the app.
        raise gr.Error(f"Generation failed: {err}") from err


# Markdown texts for credits and best practices
CREDITS_MARKDOWN = """
# Prompt Enhancer
Credits: Instructions and design inspired by [ruslanmv.com](https://ruslanmv.com).
"""

# The temperature slider bottoms out at 0.1 (and generate() clamps anything
# lower), so the advice below cites a value the user can actually select.
BEST_PRACTICES = """
**Best Practices**
- Be specific and clear in your input prompt
- Use a low temperature (e.g. 0.1) for consistent, focused results
- Increase temperature up to 1.0 for more creative variations
- Review and iterate on engineered prompts for optimal results
"""

# Build the Gradio interface with the Ocean theme
with gr.Blocks(theme=gr.themes.Ocean(), css=".gradio-container { max-width: 800px; margin: auto; }") as demo:
    # Credits at the top
    gr.Markdown(CREDITS_MARKDOWN)
    gr.Markdown(
        "Enhance your prompt to under 100 words while preserving its essence. "
        "Adjust the generation parameters as needed."
    )

    with gr.Row():
        with gr.Column(scale=1):
            input_prompt = gr.Textbox(
                label="Input Prompt",
                placeholder="Enter your prompt here...",
                lines=4,
            )
            max_tokens_slider = gr.Slider(
                label="Max New Tokens",
                minimum=50,
                maximum=512,
                step=1,
                value=256,
            )
            temperature_slider = gr.Slider(
                label="Temperature",
                minimum=0.1,
                maximum=2.0,
                step=0.1,
                value=0.9,
            )
            top_p_slider = gr.Slider(
                label="Top-p (nucleus sampling)",
                minimum=0.1,
                maximum=1.0,
                step=0.05,
                value=0.95,
            )
            repetition_penalty_slider = gr.Slider(
                label="Repetition Penalty",
                minimum=1.0,
                maximum=2.0,
                step=0.05,
                value=1.0,
            )
            generate_button = gr.Button("Enhance Prompt")

        with gr.Column(scale=1):
            output_prompt = gr.Textbox(
                label="Enhanced Prompt",
                lines=10,
                interactive=True,
            )

    # Best practices message at the bottom
    gr.Markdown(BEST_PRACTICES)

    # Wire the button click to the generate function. Because generate() is a
    # generator, Gradio streams each yielded value into the output textbox.
    generate_button.click(
        fn=generate,
        inputs=[
            input_prompt,
            max_tokens_slider,
            temperature_slider,
            top_p_slider,
            repetition_penalty_slider,
        ],
        outputs=output_prompt,
    )

if __name__ == "__main__":
    demo.launch()