| import gradio as gr |
| import torch |
| from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache |
| import os |
| import spaces |
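# `spaces` provides the @spaces.GPU decorator for Hugging Face ZeroGPU Spaces.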
|
|
| |
| |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
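# Inputs are moved to this device before generate(); the model itself is
# placed by device_map="auto" below.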
|
|
| MODEL = "pszemraj/medgemma-27b-text-heretic_med" |
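
# The tokenizer and model are loaded with local_files_only=True: the repo must
# already be in the local Hugging Face cache, so nothing is fetched at startup.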
|
|
| |
| tokenizer = AutoTokenizer.from_pretrained(MODEL, local_files_only=True) |
|
|
| |
| |
| |
| |
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    dtype=torch.bfloat16,  # load weights in bfloat16 to halve memory vs. float32
    device_map="auto",  # let accelerate place layers across available devices
    offload_folder="./offload_dir",  # disk offload for layers that fit nowhere else
    local_files_only=True,
)
|
|
if False:  # Legacy chat UI, kept for reference; disabled in favor of the Blocks app below.
| def chat_interface(message, history): |
| """ |
| Main chat function to interact with the model. |
| """ |
| chat_history = list(history) |
| |
| |
| chat_history.append({"role": "user", "content": message}) |
| |
| |
| prompt = tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True) |
| |
| |
| input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt") |
| outputs = model.generate( |
| input_ids.to(device), |
| max_new_tokens=256, |
| do_sample=True, |
| temperature=0.7, |
| top_k=50, |
| top_p=0.95 |
| ) |
| |
| |
        # Decode only the newly generated tokens, skipping the echoed prompt
        # and trailing special tokens such as <end_of_turn>.
        response = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True).strip()
| |
| return response |
| |
| |
| gr.ChatInterface( |
| fn=chat_interface, |
| type="messages", |
        title="MedGemma-27B Medical Assistant",
| description="A fine-tuned model for medical-related questions." |
| ).launch(share=True) |
|
|
@spaces.GPU(duration=60)  # Request a ZeroGPU worker for up to 60 s per call.
| def extend(text, max_new_tokens, chunk_size, progress=gr.Progress()): |
    # Gemma prompts start with a <bos> token; it is added as text here because
    # the encode call below uses add_special_tokens=False.
    PREFIX = "<bos>\n"
| progress(0, desc="Tokenizing...") |
| token_ids = tokenizer.encode(PREFIX + text, add_special_tokens=False, return_tensors="pt") |
    # Shared KV cache: each generate() call below resumes from the previous
    # chunk instead of re-processing the whole prompt.
    past_key_values = DynamicCache(config=model.config)
| done_tokens = 0 |
| try: |
| |
| while done_tokens < max_new_tokens: |
| progress(done_tokens / max_new_tokens, desc="Generating...") |
| chunk_max_new_tokens = min(chunk_size, max_new_tokens - done_tokens) |
| |
| new_ids = model.generate( |
| token_ids.to(device), |
| max_new_tokens=chunk_max_new_tokens, |
| do_sample=True, |
| temperature=0.7, |
| top_k=50, |
| top_p=0.95, |
| past_key_values=past_key_values, |
| ) |
| |
            chunk_new_tokens = new_ids.shape[1] - token_ids.shape[1]
            done_tokens += chunk_new_tokens
            token_ids = new_ids
            if chunk_new_tokens < chunk_max_new_tokens:
                # generate() returned fewer tokens than requested, i.e. the
                # model emitted an end-of-sequence token, so stop early.
                break
|
|
        # Batch size is 1: unpack the batch dimension, decode, and strip the
        # <bos> prefix added above.
        (unwrapped_new_ids,) = new_ids
        new_text = tokenizer.decode(unwrapped_new_ids).removeprefix(PREFIX)
| if not new_text.startswith(text): |
| yield text, "New text somehow deleted existing text!\n\n" + new_text |
| return |
| yield new_text, f"New tokens generated: {done_tokens}/{max_new_tokens}" |
| except Exception as e: |
| yield text, f"# ERROR: {e!r}" |
|
|
| DEBUG_ENABLED = False |
|
|
| if DEBUG_ENABLED: |
| def debug(cmd): |
| """Run `result.append(...)` to display values.""" |
| result = [] |
        # NOTE: exec() runs arbitrary Python typed into the UI; keep
        # DEBUG_ENABLED False on any public deployment.
        exec(cmd, globals(), locals())
| return repr(result) |
| else: |
| def debug(x): |
| """Debug print the input.""" |
| return repr(x) |
|
|
| with gr.Blocks() as demo: |
| gr.Markdown("# Medical Text Generation") |
| gr.Markdown(f"Model in use: {MODEL}") |
| with gr.Tab("Extend"): |
| gr.Markdown("Enter some medical text, and press Generate to continue it.") |
        gr.Markdown("Generation runs in chunks so it can be aborted; the KV cache is reused across chunks within a single run (not yet across separate runs).")
        gr.Markdown("Raising the chunk size increases the time between updates, but can speed up overall generation by reducing per-chunk overhead.")
| document = gr.Code( |
| language="markdown", |
| interactive=True, |
| wrap_lines=True, |
| ) |
| max_new_tokens = gr.Slider(label="Max New Tokens", minimum=10, maximum=8192, step=10, value=128) |
| chunk_size = gr.Slider(label="Streaming Chunk Size", minimum=1, maximum=100, step=1, value=5) |
| with gr.Row(): |
| generate_button = gr.Button("Generate") |
| abort_button = gr.Button("Abort") |
        status = gr.Code(
            label="Status",
            language="markdown",
            interactive=False,
            wrap_lines=True,
        )
        generate_event = generate_button.click(
            fn=extend,
            inputs=[document, max_new_tokens, chunk_size],
            outputs=[document, status],
            show_progress="minimal",
        )
| abort_button.click( |
| fn=None, |
| inputs=None, |
| outputs=None, |
| cancels=[generate_event], |
| ) |
| with gr.Tab("Debug"): |
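        # gr.Interface can be nested inside a Blocks tab; the docstring of
        # whichever debug() was defined above doubles as the input label.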
| gr.Interface( |
| fn=debug, |
| inputs=[gr.Code( |
| label=debug.__doc__, |
| language="python", |
| interactive=True, |
| wrap_lines=True, |
| )], |
| outputs=[gr.Code( |
| language="python", |
| wrap_lines=True, |
| )], |
| ) |
| demo.launch() |
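
# Local usage sketch (assuming this file is saved as app.py):
#   pip install gradio torch transformers accelerate spaces
#   python app.py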