|
import gradio as gr |
|
import os |
|
from huggingface_hub import login |
|
|
|
from transformers import AutoModelForSeq2SeqLM, T5Tokenizer |
|
from peft import PeftModel, PeftConfig |
|
|
|
|
|
# Hugging Face access token read from the "token" environment variable
# (e.g. a deployment secret). None when the variable is unset.
token = os.environ.get("token")

if token:
    login(token)
    print("login is succesful")
else:
    # huggingface_hub.login() with no token falls back to an interactive
    # prompt, which cannot be answered in a headless deployment, so skip it.
    print("No 'token' environment variable set; skipping Hugging Face login")

# Maximum sequence length passed to model.generate().
max_length = 512

# Base checkpoint and the PEFT adapter that is applied on top of it.
MODEL_NAME = "google/flan-t5-base"
ADAPTER_NAME = "Orcawise/eu_ai_act_orcawise_july12"

tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, token=token)

# Adapter configuration; loaded but not used by the code visible here.
config = PeftConfig.from_pretrained(ADAPTER_NAME)

base_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, token=token)

# Wrap the base model with the fine-tuned adapter weights.
model = PeftModel.from_pretrained(base_model, ADAPTER_NAME)
|
|
|
|
|
|
|
def generate_text(prompt):
    """Generate a reply to ``prompt`` with the PEFT-adapted seq2seq model.

    Relies on the module-level ``tokenizer``, ``model`` and ``max_length``.

    Args:
        prompt (str): The user-provided prompt to start the generation.

    Returns:
        str: The decoded text of the highest-scoring beam, with special
        tokens removed.
    """
    # The tokenizer output carries both input_ids and attention_mask; passing
    # both tells generate() which positions are real tokens.
    inputs = tokenizer(prompt, return_tensors="pt")

    # Deterministic beam search. Sampling-only knobs (temperature, top_p) have
    # no effect when do_sample is False, so they are not passed.
    outputs = model.generate(
        **inputs,
        max_length=max_length,
        num_beams=5,
        repetition_penalty=1.5,
        early_stopping=True,
    )
    print("show the output", outputs)

    # generate() returns a batch of token-id sequences; decode the single
    # sequence produced for our one prompt. (Indexing the decoded *string*
    # with [0] would keep only its first character.)
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("show the generated text", generated_text)
    return generated_text
|
|
|
# Custom stylesheet passed to gr.Blocks(css=...) to theme the chat UI.
# The `.submit-btn` / `.clear-btn` rules target the elem_classes given to the
# Submit and Clear buttons.
# NOTE(review): selectors containing `.svelte-1s78gfg` / `.svelte-8tpqd2` look
# like framework-generated scoping classes and will likely stop matching after
# a Gradio upgrade; confirm against the Gradio version in use.
custom_css="""

.message.pending {

background: #A8C4D6;

}

/* Response message */

.message.bot.svelte-1s78gfg.message-bubble-border {

/* background: white; */

border-color: #266B99

}

/* User message */

.message.user.svelte-1s78gfg.message-bubble-border{

background: #9DDDF9;

border-color: #9DDDF9



}

/* For both user and response message as per the document */

span.md.svelte-8tpqd2.chatbot.prose p {

color: #266B99;

}

/* Chatbot comtainer */

.gradio-container{

/* background: #84D5F7 */

}

/* RED (Hex: #DB1616) for action buttons and links only */

.clear-btn {

background: #DB1616;

color: white;

}

/* #84D5F7 - Primary colours are set to be used for all sorts */

.submit-btn {

background: #266B99;

color: white;

}

"""
|
|
|
|
|
with gr.Blocks(css=custom_css) as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Ask your question...")
    submit_button = gr.Button("Submit", elem_classes="submit-btn")
    clear = gr.Button("Clear", elem_classes="clear-btn")

    def user(user_message, history):
        """Append the user's message to the history and clear the textbox.

        Args:
            user_message (str): Text typed into the textbox.
            history (list | None): Chat history as ``[user, bot]`` pairs;
                may be None/empty after the Clear button resets the chatbot.

        Returns:
            tuple: ``("", new_history)``; the empty string clears the textbox.
        """
        history = list(history or [])
        # Ignore blank submissions instead of sending an empty prompt to the model.
        if not user_message or not user_message.strip():
            return "", history
        return "", history + [[user_message, None]]

    def bot(history):
        """Stream the assistant's reply into the last history entry.

        Yields the updated history after each character so the Chatbot shows
        a typing effect. Greets the user when the history is empty.

        Args:
            history (list | None): Chat history as ``[user, bot]`` pairs.
        """
        history = history or []
        if not history:
            # Nothing has been said yet: greet the user.
            bot_message = "Hi there! How can I help you today?"
            history.append([None, ""])
        elif history[-1][1] is not None:
            # The last message already has a reply (e.g. a blank submission
            # was ignored), so there is nothing new to answer.
            yield history
            return
        else:
            bot_message = generate_text(history[-1][0])
            history[-1][1] = ""

        if not bot_message:
            # Emit once anyway so the pending placeholder is replaced.
            yield history
            return

        for character in bot_message:
            history[-1][1] += character
            yield history

    # Pressing Enter in the textbox and clicking Submit share the same flow:
    # echo the user's message immediately (unqueued), then stream the reply.
    for trigger in (submit_button.click, msg.submit):
        trigger(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, chatbot, chatbot
        )

    # Setting the Chatbot value to None clears the displayed conversation.
    clear.click(lambda: None, None, chatbot, queue=False)

# `bot` is a generator (streaming output); older Gradio versions only allow
# generator handlers when the queue is enabled explicitly.
demo.queue().launch()