"""Gradio demo: chat with Mistral-7B via the Hugging Face Inference API,
with biomedical named-entity recognition (NER) highlighting of the user's message."""

import gradio as gr
from huggingface_hub import InferenceClient
from transformers import pipeline

# Client for the hosted Mistral-7B-Instruct model (Hugging Face Inference API).
client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.3")
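# Note: calling the hosted Inference API usually requires a Hugging Face token
# (e.g. one saved by `huggingface-cli login` or exposed via the HF_TOKEN environment
# variable); this is an assumption about the runtime environment, not something
# this script configures.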

# Token-classification pipeline for biomedical named-entity recognition.
ner_pipeline = pipeline("ner", model="d4data/biomedical-ner-all")

def extract_entities(text):
    """Run biomedical NER on `text` and merge WordPiece sub-tokens into whole words."""
    entities = ner_pipeline(text)

    merged_entities = []
    current_word = ""
    current_entity = None

    for ent in entities:
        word = ent["word"]
        entity_type = ent["entity"].split("-")[-1]  # drop the B-/I- prefix

        if word.startswith("#"):
            # Sub-token continuation (e.g. "##itis"): append to the word being built.
            current_word += word.lstrip("#")
        else:
            # A new token starts: flush the previously accumulated word, if any.
            if current_word and current_entity:
                merged_entities.append({"word": current_word, "entity": current_entity})
            current_word = word
            current_entity = entity_type

    # Flush the final accumulated word.
    if current_word and current_entity:
        merged_entities.append({"word": current_word, "entity": current_entity})

    return merged_entities

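# Illustrative (hypothetical) return value of extract_entities — the actual words
# and entity labels depend on the d4data/biomedical-ner-all model:
#   [{"word": "aspirin", "entity": "Medication"},
#    {"word": "headache", "entity": "Sign_symptom"}]
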
def highlight_text(text, entities):
    """Wrap recognized entity words in `text` with a highlight <span> and list them."""
    words = text.split(' ')
    entity_words = {ent["word"].lower(): ent for ent in entities}

    for i, word in enumerate(words):
        # Strip surrounding punctuation so e.g. "aspirin," still matches "aspirin".
        clean_word = word.strip('.,!?()[]')
        lower_word = clean_word.lower()

        if lower_word in entity_words:
            words[i] = f"<span style='background-color: #ffcc80; color: black; padding: 2px; border-radius: 4px;'>{word}</span>"

    highlighted_text = ' '.join(words)

    # Append a summary list of the recognized entities below the highlighted text.
    if entities:
        entity_list = "<h4>Recognized Medical Entities:</h4><ul>"
        for ent in entities:
            entity_list += f"<li><strong>{ent['word']}</strong> ({ent['entity']})</li>"
        entity_list += "</ul>"
    else:
        entity_list = "<p><em>No medical entities detected.</em></p>"

    return highlighted_text + "<br><br>" + entity_list

def chat_with_ner(message, history):
    """Run NER on the message, query the model, and update the chat history."""
    entities = extract_entities(message)
    recognized_entities = [ent["word"] for ent in entities]

    if recognized_entities:
        prompt = f"This text contains medical terms: {', '.join(recognized_entities)}. Please explain briefly."
        show_to_history = f"Medical entities recognized: {', '.join(recognized_entities)}. Here is some information about them."
    else:
        prompt = message
        show_to_history = message

    response = client.text_generation(prompt, max_new_tokens=100)
    highlighted_message = highlight_text(message, entities)

    history.append((show_to_history, response))
    return history, highlighted_message

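# Assumption: the chat history is kept as (user, bot) tuples, which matches the
# older tuple-style gr.Chatbot format; newer Gradio releases prefer the "messages"
# format, so the installed Gradio version is assumed to still accept tuples.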
with gr.Blocks() as demo:
    gr.Markdown(
        "<h1 style='text-align: center;'>Mistral AI Chatbot</h1>"
        "<p style='text-align: center;'>Chat with Mistral-7B and experience advanced AI conversations!</p>",
    )

    chatbot = gr.Chatbot(label="Mistral AI Assistant")
    message = gr.Textbox(placeholder="Type your message here...", label="Your Message")
    highlighted_output = gr.HTML(label="Highlighted Text (NER)")
    send_btn = gr.Button("Send")

    # Wire the button directly to the NER-aware chat handler.
    send_btn.click(chat_with_ner, inputs=[message, chatbot], outputs=[chatbot, highlighted_output])


demo.launch()