#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Feb 13 11:22:52 2024

@author: stinpankajm
"""
import base64  # used only by the commented-out logo-embedding block below
from huggingface_hub import InferenceClient
import gradio as gr

client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")
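# Note (assumption, not from the original source): the hosted Inference API may
# require authentication for this model. If so, pass a token explicitly, e.g.
#   InferenceClient("mistralai/Mistral-7B-Instruct-v0.2", token="hf_...")
# or set the HF_TOKEN environment variable, which huggingface_hub picks up.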

# Formats the prompt so it carries all past turns in Mistral's [INST] chat
# format (kept for reference; `generate` below uses format_prompt_finadvisor)
def format_prompt(message, history):
    prompt = "<s>"
    prompt_template = "[INST] {} [/INST]"

    # Iterates through every past user input and response to be added to the prompt
    for user_prompt, bot_response in history:
        prompt += prompt_template.format(user_prompt)
        prompt += f" {bot_response}</s> "
        
    prompt += prompt_template.format(message)
    return prompt
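
# Illustrative example (hypothetical inputs) of the resulting prompt shape:
#   format_prompt("How do I treat aphids?", [("Hello", "Hi, how can I help?")])
# -> "<s>[INST] Hello [/INST] Hi, how can I help?</s> [INST] How do I treat aphids? [/INST]"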

# Local fine-tuned model path (currently unused; inference goes through the hosted client above)
MODEL_PATH = "/home/stinpankajm/workspace/FG_POCs/Insects_Scouting/Models/F_Scout_v0.2"


css = """
#warning {background-color: #FFCCCB} 
#flag {color: red;}
#topHeading {
    padding: 30px 0px 30px 15px;
    box-shadow: 1px 0px 30px 0px rgba(0, 0, 0, 0.1);
}
#logoImg {
    max-width: 260px;
}
"""


# Formats the FinAdvisor prompt; the instruction prefix below is prepended to
# every turn, including past turns from the history
def format_prompt_finadvisor(message, history):
    prompt = "<s>"

    # Instruction prefix added before every user message (built without stray
    # indentation so no leading spaces leak into the prompt)
    prompt_prefix = (
        "You are an agriculture expert providing advice to farmers and users.\n\n"
        "Your task is to answer questions related to agriculture based on the question provided below.\n\n"
        "Do not provide any explanations and respond only with medium-short answers, adding bullet points whenever necessary.\n\n"
        "Your TEXT to analyze:\n"
    )

    prompt_template = "[INST] " + prompt_prefix + " {} [/INST]"

    # Iterates through every past user input and response to be added to the prompt
    for user_prompt, bot_response in history:
        prompt += prompt_template.format(user_prompt)
        prompt += f" {bot_response}</s> \n"

    prompt += prompt_template.format(message)
    return prompt


def generate(prompt, history, system_prompt, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )
    
    #formatted_prompt = format_prompt_grammar(f"Corrected Sentence: {prompt}", history)
    formatted_prompt = format_prompt_finadvisor(f"{system_prompt} {prompt}", history)
    # print("\nPROMPT: \n\t" + formatted_prompt)

    # Stream the generation from the HF Inference endpoint
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )
    output = ""

    # Yield the accumulating text after each token so the UI updates live
    for response in stream:
        output += response.token.text
        yield output
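
# Minimal usage sketch outside Gradio (assumes the endpoint is reachable and
# any required token is configured); `generate` is a generator, so iterate it:
#   for partial in generate("Which crops suit sandy soil?", history=[], system_prompt=""):
#       print(partial)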

additional_inputs = [
    gr.Textbox(label="System Prompt", value="", max_lines=1, interactive=True),
    gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True,
              info="Higher values produce more diverse outputs"),
    gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=1024, step=64, interactive=True,
              info="The maximum number of new tokens"),
    gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.0, maximum=1.0, step=0.05, interactive=True,
              info="Higher values sample more low-probability tokens"),
    gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, interactive=True,
              info="Penalize repeated tokens"),
]

with gr.Blocks(css=css) as demo:
    # Top custom header
    with gr.Row(elem_id="topHeading"):
        # with gr.Column(elem_id="logoImg"):
        #     with open('./static/logo.jpeg', "rb") as image:
        #         encoded = base64.b64encode(image.read()).decode()
        #         logo_image = f"data:image/png;base64,{encoded}"

        #         gr.HTML(f'<img src={logo_image} style="width:155px">')

            # gr.Image(Image.open('./static/FarmGyan logo_1.png'))
        with gr.Column():
            gr.Markdown("# FinAdvisor")

    # Model prediction: the chat UI is built inside the Blocks context so the
    # header above renders together with it
    gr.ChatInterface(
        fn=generate,
        chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
        additional_inputs=additional_inputs,
        title="AgExpert",
        examples=[],
    )

demo.queue().launch()
# demo.queue().launch(auth=("shivraiAdmin", "FarmERP@2024"), auth_message="Please enter your credentials to get started.")
# demo.launch(show_api=False)