#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Gradio chat UI that streams answers from a hosted Mistral-7B-Instruct model.

The script builds an ``[INST] ... [/INST]`` style prompt from the chat
history, sends it to the Hugging Face ``InferenceClient`` and streams the
generated text back into a ``gr.ChatInterface``.

Created on Tue Feb 13 11:22:52 2024

@author: stinpankajm
"""

import base64
import os

import gradio as gr
from huggingface_hub import InferenceClient

client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")

# Not referenced anywhere else in this file.
MODEL_PATH = "/home/stinpankajm/workspace/FG_POCs/Insects_Scouting/Models/F_Scout_v0.2"

# Custom CSS handed to gr.Blocks; "topHeading" is the elem_id of the header row.
css = (
    " #warning {background-color: #FFCCCB}"
    " #flag {color: red;}"
    " #topHeading { padding: 30px 0px 30px 15px;"
    " box-shadow: 1px 0px 30px 0px rgba(0, 0, 0, 0.1); }"
    " #logoImg { max-width: 260px; } "
)

# Instruction text placed in front of every user turn by
# format_prompt_finadvisor().
_ADVISOR_PROMPT_PREFIX = (
    "You are an agriculture expert providing advice to farmers and users. "
    "Your task is to answer questions related to agriculture based on the "
    "Question provided below. Do not provide any explanations and respond "
    "only with medium short answers, add bullet points whenever necessary.. \n"
    "Your TEXT to analyze: "
)


def _join_turns(template, message, history, turn_end):
    """Render each past exchange and the new message through *template*.

    Every ``(user_prompt, bot_response)`` pair in *history* becomes
    ``template.format(user_prompt)``, a space, the bot response and
    *turn_end*; the new *message* is appended last, formatted the same way
    but without a response.
    """
    past_turns = "".join(
        f"{template.format(user_prompt)} {bot_response}{turn_end}"
        for user_prompt, bot_response in history
    )
    return past_turns + template.format(message)


def format_prompt(message, history):
    """Return a prompt holding all past messages followed by *message*.

    Args:
        message: The new user input.
        history: Iterable of ``(user_prompt, bot_response)`` pairs.

    Returns:
        The concatenated ``[INST] ... [/INST]`` prompt string.
    """
    return _join_turns("[INST] {} [/INST]", message, history, " ")


# Use for GEC, Doesn't track actual history
def format_prompt_finadvisor(message, history):
    """Return a prompt in which every user turn carries the advisor prefix.

    Args:
        message: The new user input.
        history: Iterable of ``(user_prompt, bot_response)`` pairs.

    Returns:
        The prompt string; each past exchange ends with a newline and each
        user turn is wrapped as ``[INST] <prefix> <text> [/INST]``.
    """
    template = "[INST] " + _ADVISOR_PROMPT_PREFIX + " {} [/INST]"
    return _join_turns(template, message, history, " \n")


def generate(prompt, history, system_prompt, temperature=0.9,
             max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    """Stream the model's answer to *prompt*, yielding the text so far.

    Args:
        prompt: The new user message.
        history: Iterable of ``(user_prompt, bot_response)`` pairs.
        system_prompt: Text placed in front of *prompt* (separated by a
            single space) before formatting.
        temperature: Sampling temperature; values below 0.01 are raised
            to 0.01.
        max_new_tokens: Upper bound on generated tokens; coerced to an
            integer and raised to at least 1.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.

    Yields:
        The accumulated generated text after each streamed token.

    Returns:
        The complete generated text (available as ``StopIteration.value``).
    """
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    # The "Max new tokens" slider allows 0 and UI values may not arrive as
    # ints; a request for zero tokens cannot produce any output.
    max_new_tokens = max(int(max_new_tokens), 1)

    generate_kwargs = {
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
        "seed": 42,
    }

    formatted_prompt = format_prompt_finadvisor(f"{system_prompt} {prompt}", history)

    # Generate text from the HF inference
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )

    output = ""
    for response in stream:
        output += response.token.text
        yield output
    return output


# Extra controls shown with the chat; their values are passed to generate()
# positionally after (prompt, history), in this order.
additional_inputs = [
    gr.Textbox(
        label="System Prompt",
        value="",
        max_lines=1,
        interactive=True,
    ),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=1048,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]


with gr.Blocks(css=css) as demo:
    # Top custom header
    with gr.Row(elem_id="topHeading"):
        # A logo column (elem_id="logoImg", read from ./static/logo.jpeg and
        # embedded as a base64 data URI) was previously rendered here and is
        # currently disabled.
        with gr.Column():
            gr.Markdown(
                """ # FinAdvisor """,
            )

    # Model prediction
    # NOTE(review): the ChatInterface is launched directly and `demo` itself
    # is never launched -- confirm the header row above is actually shown.
    # To require a login, pass auth=(user, password) and auth_message=... to
    # launch(); load the credentials from the environment instead of
    # hard-coding them in this file.
    gr.ChatInterface(
        fn=generate,
        chatbot=gr.Chatbot(
            show_label=False,
            show_share_button=False,
            show_copy_button=True,
            likeable=True,
            layout="panel",
        ),
        additional_inputs=additional_inputs,
        title="AgExpert",
        examples=[],
    ).queue().launch()