#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Feb 13 11:22:52 2024
@author: stinpankajm
"""
import os
import base64

from huggingface_hub import InferenceClient
import gradio as gr

# Client for the hosted Hugging Face Inference API endpoint of the model
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")
# Formats the prompt so that it carries all of the past messages
def format_prompt(message, history):
    prompt = "<s>"
    prompt_template = "[INST] {} [/INST]"
    # Append every past user input and bot response to the prompt
    for user_prompt, bot_response in history:
        prompt += prompt_template.format(user_prompt)
        prompt += f" {bot_response}</s> "
    prompt += prompt_template.format(message)
    return prompt
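
# Illustrative example (not part of the app flow) of the prompt this produces:
#   format_prompt("How do I rotate crops?", [("Hi", "Hello!")])
#   -> "<s>[INST] Hi [/INST] Hello!</s> [INST] How do I rotate crops? [/INST]"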
# Local model checkpoint path (not used by the Inference API client above)
MODEL_PATH = "/home/stinpankajm/workspace/FG_POCs/Insects_Scouting/Models/F_Scout_v0.2"
css = """
#warning {background-color: #FFCCCB}
#flag {color: red;}
#topHeading {
    padding: 30px 0px 30px 15px;
    box-shadow: 1px 0px 30px 0px rgba(0, 0, 0, 0.1);
}
#logoImg {
    max-width: 260px;
}
"""
# Builds the FinAdvisor prompt; the instruction prefix is repeated before every turn
def format_prompt_finadvisor(message, history):
    prompt = "<s>"
    # Instructions prepended to every user message
    prompt_prefix = """\
You are an agriculture expert providing advice to farmers and users.
Your task is to answer questions related to agriculture based on the Question provided below.
Do not provide explanations; respond only with short answers, and add bullet points whenever necessary.
Your TEXT to analyze:
"""
    prompt_template = "[INST] " + prompt_prefix + " {} [/INST]"
    # Append every past user input and bot response to the prompt
    for user_prompt, bot_response in history:
        prompt += prompt_template.format(user_prompt)
        prompt += f" {bot_response}</s> \n"
    prompt += prompt_template.format(message)
    return prompt
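
# Illustrative example (not executed by the app): with empty history,
#   format_prompt_finadvisor("Best fertilizer for wheat?", [])
# returns "<s>[INST] <prefix> Best fertilizer for wheat? [/INST]",
# where <prefix> is the instruction block defined above.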
def generate(prompt, history, system_prompt, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    temperature = float(temperature)
    # The API rejects a temperature of 0, so clamp it to a small positive value
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    # formatted_prompt = format_prompt_grammar(f"Corrected Sentence: {prompt}", history)
    formatted_prompt = format_prompt_finadvisor(f"{system_prompt} {prompt}", history)
    # print("\nPROMPT: \n\t" + formatted_prompt)

    # Stream tokens from the HF Inference API and yield the accumulated output so far
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""
    for response in stream:
        output += response.token.text
        yield output
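
# Illustrative local check (assumes network access to the Inference API; not part
# of the Gradio app itself):
#   for partial in generate("What is crop rotation?", [], ""):
#       print(partial)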
additional_inputs = [
    gr.Textbox(label="System Prompt", value="", max_lines=1, interactive=True),
    gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values produce more diverse outputs"),
    gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=1048, step=64, interactive=True, info="The maximum number of new tokens"),
    gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values sample more low-probability tokens"),
    gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, interactive=True, info="Penalize repeated tokens"),
]
with gr.Blocks(css=css) as demo:
    """
    Top Custom Header
    """
    with gr.Row(elem_id="topHeading"):
        # with gr.Column(elem_id="logoImg"):
        #     with open('./static/logo.jpeg', "rb") as image:
        #         encoded = base64.b64encode(image.read()).decode()
        #     logo_image = f"data:image/png;base64,{encoded}"
        #     gr.HTML(f'<img src={logo_image} style="width:155px">')
        #     gr.Image(Image.open('./static/FarmGyan logo_1.png'))
        with gr.Column():
            gr.Markdown(
                """
                # FinAdvisor
                """,
            )

    """
    Model Prediction
    """
    gr.ChatInterface(
        fn=generate,
        chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
        additional_inputs=additional_inputs,
        title="AgExpert",
        examples=[],
    ).queue().launch()
    # ).queue().launch(auth=("shivraiAdmin", "FarmERP@2024"), auth_message="Please enter your credentials to get started.")

# demo.launch(show_api=False)