import os
import spaces
import torch
import gradio as gr

# CPU context: under ZeroGPU there is no GPU attached outside a @spaces.GPU function,
# so the tensor still reports a CPU device here.
zero = torch.Tensor([0]).cuda()
print(zero.device)  # <-- 'cpu' 🤔

# GPU context: the model is loaded lazily on the first request inside greet().
model = None


@spaces.GPU
def greet(prompts, separator):
    # print(zero.device)  # <-- 'cuda:0' 🤗
    from vllm import SamplingParams, LLM
    from transformers.utils import move_cache
    from huggingface_hub import snapshot_download

    global model
    if model is None:
        LLM_MODEL_ID = "DoctorSlimm/trim-music-31"
        # LLM_MODEL_ID = "mistral-community/Mistral-7B-v0.2"
        # LLM_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"
        os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
        fp = snapshot_download(LLM_MODEL_ID, token=os.getenv('HF_TOKEN'), revision='main')
        move_cache()
        model = LLM(fp)

    sampling_params = SamplingParams(
        temperature=0.01,
        ignore_eos=False,
        max_tokens=int(512 * 2)
    )

    # Split the input into multiple prompts only when a non-empty separator is present;
    # an empty separator would make str.split() raise a ValueError.
    multi_prompt = False
    separator = separator.strip()
    if separator and separator in prompts:
        multi_prompt = True
        prompts = prompts.split(separator)
    else:
        prompts = [prompts]

    for idx, pt in enumerate(prompts):
        print()
        print(f'[{idx}]:')
        print(pt)

    model_outputs = model.generate(prompts, sampling_params)
    generations = []
    for output in model_outputs:
        for completion in output.outputs:
            generations.append(completion.text)

    if multi_prompt:
        return separator.join(generations)
    return generations[0]


## make predictions via api ##
# https://www.gradio.app/guides/getting-started-with-the-python-client#connecting-a-general-gradio-app
demo = gr.Interface(
    fn=greet,
    inputs=[
        gr.Text(
            value='hello sir!bonjour madame...',
            placeholder='hello sir!bonjour madame...',
            label='list of prompts separated by separator'
        ),
        gr.Text(
            value='',
            placeholder='',
            label='separator for your prompts'
        )
    ],
    outputs=gr.Text()
)
demo.launch(share=True)
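
# The guide linked above covers calling this app with the Gradio Python client.
# A minimal client-side sketch (run as a separate script, not inside this app) could
# look like the block below; the Space id "user/space-name" is a placeholder, and
# api_name assumes gr.Interface's default "/predict" endpoint:
#
#     from gradio_client import Client
#
#     client = Client("user/space-name")  # hypothetical Space id, replace with this repo's id
#     result = client.predict(
#         "hello sir!bonjour madame...",  # prompts, joined by the separator
#         "!",                            # separator used to split the prompts
#         api_name="/predict",
#     )
#     print(result)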