import gradio as gr
import requests
import torch
import transformers
import einops

### Generation setup: imports and prompt template
from typing import Any, Dict, Tuple
import warnings
import os
from threading import Event, Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
INSTRUCTION_KEY = "### Instruction:"
RESPONSE_KEY = "### Response:"
END_KEY = "### End"
INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
PROMPT_FOR_GENERATION_FORMAT = """{intro}
{instruction_key}
{instruction}
{response_key}
""".format(
    intro=INTRO_BLURB,
    instruction_key=INSTRUCTION_KEY,
    instruction="{instruction}",
    response_key=RESPONSE_KEY,
)
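# For illustration, format()-ing the template above with the instruction
# "Say hello" yields:
#
#   Below is an instruction that describes a task. Write a response that appropriately completes the request.
#   ### Instruction:
#   Say hello
#   ### Response: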
## Pipeline and helper setup
from InstructionTextGenerationPipeline import *
from timeit import default_timer as timer
import time
from datetime import datetime
import json
# Make the shared notebook library importable
import sys
import os.path as osp
import pprint

pp = pprint.PrettyPrinter(indent=4)

LIBRARY_PATH = "/home/ec2-user/workspace/Notebooks/lib"
module_path = os.path.abspath(LIBRARY_PATH)
if module_path not in sys.path:
    sys.path.append(module_path)
print(f"sys.path : {sys.path}")
def complete(state="complete"):
    """Print a simple progress marker between long-running steps."""
    print(f"\nCell {state}")

complete(state='imports done')
complete(state="start generate")
name = 'mosaicml/mpt-7b-instruct'
config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
config.attn_config['attn_impl'] = 'torch'
config.init_device = 'cuda:0'  # For fast initialization directly on GPU!

generate = InstructionTextGenerationPipeline(
    name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    config=config,
)
# MPT-7B uses the EleutherAI gpt-neox-20b tokenizer, whose end-of-text token
# is "<|endoftext|>"; generation stops when it appears
stop_token_ids = generate.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])
complete(state="Model loaded")
# Define a custom stopping criteria: halt generation as soon as the model
# emits one of the stop tokens
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only the newest token of the (single) sequence needs checking
        for stop_id in stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
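# Note: a similar cutoff could be had by passing eos_token_id=stop_token_ids[0]
# to model.generate(); an explicit StoppingCriteria generalizes more easily to
# multiple stop tokens.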
def process_stream(instruction, temperature, top_p, top_k, max_new_tokens):
    # Tokenize the input
    input_ids = generate.tokenizer(
        generate.format_instruction(instruction), return_tensors="pt"
    ).input_ids
    input_ids = input_ids.to(generate.model.device)

    # Initialize the streamer and stopping criteria
    streamer = TextIteratorStreamer(
        generate.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    stop = StopOnTokens()

    # Treat very low temperatures as greedy decoding
    if temperature < 0.1:
        temperature = 0.0
        do_sample = False
    else:
        do_sample = True

    gkw = {
        **generate.generate_kwargs,
        **{
            "input_ids": input_ids,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "do_sample": do_sample,
            "top_p": top_p,
            "top_k": top_k,
            "streamer": streamer,
            "stopping_criteria": StoppingCriteriaList([stop]),
        },
    }

    # Run generation on a background thread so the streamer can be consumed here
    def generate_text():
        generate.model.generate(**gkw)

    t1 = Thread(target=generate_text)
    t1.start()

    # Accumulate streamed tokens into the full response
    response = ''
    for new_text in streamer:
        response += new_text
    t1.join()
    return response
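# Example call with hypothetical settings (blocks until generation finishes):
# process_stream("Explain what a tokenizer does.",
#                temperature=0.3, top_p=0.95, top_k=0, max_new_tokens=256)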
gr.close_all()

def tester(uPrompt, max_new_tokens, temperature, top_k, top_p):
    response = process_stream(uPrompt, temperature, top_p, top_k, max_new_tokens)
    # Log the prompt and generation settings
    print(f"{uPrompt} max_new_tokens={max_new_tokens}; temperature={temperature}; "
          f"top_k={top_k}; top_p={top_p};")
    return response
demo = gr.Interface(
    fn=tester,
    inputs=[
        gr.Textbox(label="Prompt", info="Prompt", lines=3, value="Provide Prompt"),
        gr.Slider(256, 3072, value=1024, step=256, label="Tokens"),
        gr.Slider(0.0, 1.0, value=0.1, step=0.1, label="temperature"),
        gr.Slider(0, 1, value=0, step=1, label="top_k"),  # 0 disables top-k filtering
        gr.Slider(0.0, 1.0, value=0.05, step=0.05, label="top_p"),
    ],
    outputs=["text"],
    title="Mosaic MPT-7B",
)
demo.launch(
    share=True,
    server_name="0.0.0.0",
    server_port=7860,
)
# Note on how we can run with SSL
# See: https://github.com/gradio-app/gradio/issues/563
# a = gr.Interface(lambda x: x, "image", "image", examples=["lion.jpg"]).launch(
#     share=False, ssl_keyfile="key.pem", ssl_certfile="cert.pem")
# It seems we need an appropriate non-self-signed cert that the customer will
# accept on their network.
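# Applying that pattern to this app would look like the following sketch
# (assumes key.pem / cert.pem are signed by a CA the clients already trust):
# demo.launch(share=False, server_name="0.0.0.0", server_port=7860,
#             ssl_keyfile="key.pem", ssl_certfile="cert.pem")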