# %%
import gradio as gr
from text_generation import Client

# Connect to the text-generation-inference server hosting the model.
client = Client("http://20.83.177.108:8080")

# Quick sanity check of the streaming endpoint:
# text = ""
# for response in client.generate_stream("What is Deep Learning?", max_new_tokens=20):
#     if not response.token.special:
#         text += response.token.text
# print(text)


def run_generation(user_text, top_p, temperature, top_k, max_new_tokens):
    # Wrap the user's question in the legal-assistant instruction prompt.
    prompt = f"""You are an expert legal assistant with extensive knowledge about Indian law. Your task is to respond to the given query in a concise and factually correct manner. Also mention the relevant sections of the law wherever applicable.

### Input:
{user_text}

### Response:
"""
    # Stream tokens from the server, yielding the accumulated text so the
    # Gradio output textbox updates incrementally.
    text = ""
    for response in client.generate_stream(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=1.05,
    ):
        if not response.token.special:
            text += response.token.text
            yield text


def reset_textbox():
    return gr.update(value='')


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                placeholder="What is the punishment for taking dowry? Explain in detail.",
                label="Question",
            )
            model_output = gr.Textbox(label="AI Response", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit")

        with gr.Column(scale=1):
            max_new_tokens = gr.Slider(
                minimum=1, maximum=1000, value=250, step=10, interactive=True,
                label="Max New Tokens",
            )
            top_p = gr.Slider(
                minimum=0.05, maximum=1.0, value=0.95, step=0.05, interactive=True,
                label="Top-p (nucleus sampling)",
            )
            top_k = gr.Slider(
                minimum=1, maximum=50, value=50, step=1, interactive=True,
                label="Top-k",
            )
            temperature = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.8, step=0.1, interactive=True,
                label="Temperature",
            )

    user_text.submit(run_generation, [user_text, top_p, temperature, top_k, max_new_tokens], model_output)
    button_submit.click(run_generation, [user_text, top_p, temperature, top_k, max_new_tokens], model_output)

demo.queue(max_size=32).launch(share=True)