import asyncio

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class ChatClient:
    """Local chat client that runs a causal language model in-process.

    Loads a tokenizer and model from ``model_path`` and exposes an async
    ``chat_completion`` generator whose call shape loosely mirrors
    ``huggingface_hub.InferenceClient.chat_completion``.
    """

    def __init__(self, model_path):
        """Load the tokenizer and model, placing the model on the GPU if available.

        Args:
            model_path: Path (or identifier) accepted by ``from_pretrained``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(model_path).to(self.device)
        self.model.eval()  # inference only (e.g. disables dropout)

    async def chat_completion(self, messages, max_tokens, stream=False, temperature=1.0, top_p=1.0):
        """Generate a reply to ``messages``.

        Args:
            messages: The full prompt as one string (the caller merges it).
            max_tokens: Maximum number of new tokens to generate.
            stream: Accepted for call compatibility but currently unused; the
                reply is produced in one piece, not token by token.
            temperature: Sampling temperature (sampling is always enabled).
            top_p: Nucleus-sampling probability mass.

        Yields:
            The decoded completion (prompt tokens excluded), one string per
            generated sequence.
        """
        input_text = messages
        print(input_text)

        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.device)
        prompt_len = inputs["input_ids"].shape[1]

        gen_kwargs = {
            # Equivalent to max_length = prompt_len + max_tokens; int() guards
            # against the UI slider delivering a float.
            "max_new_tokens": int(max_tokens),
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": True,
        }

        # generate() is a blocking call; run it in a worker thread so the
        # event loop serving the UI is not stalled while the model runs.
        output_sequences = await asyncio.to_thread(self.model.generate, **inputs, **gen_kwargs)

        for sequence in output_sequences:
            # The returned sequence is prompt + continuation; decode only the
            # newly generated tokens so the reply does not echo the prompt.
            result_text = self.tokenizer.decode(sequence[prompt_len:], skip_special_tokens=True)
            await asyncio.sleep(0)  # yield control back to the event loop
            yield result_text


# Create the client instance from the local model directory.
model_path = 'model/v3/'
client = ChatClient(model_path)


async def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """ChatInterface callback: yield the model's reply to ``message``.

    The prompt is ``system_message`` concatenated directly with ``message``.
    ``history`` is supplied by ChatInterface but is NOT included in the
    prompt, so each turn is answered independently.
    """
    prompt = system_message + message
    async for chunk in client.chat_completion(
        prompt,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        yield chunk


# ChatInterface customization options: https://www.gradio.app/docs/chatinterface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="Yahoo!ショッピングについての質問を回答してください。", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=1024, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.1, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)


if __name__ == "__main__":
    demo.launch()