File size: 1,937 Bytes
89d02ed
505eb00
78c11dd
89d02ed
d3ef534
2b38d1f
73251f3
d3ef534
 
 
 
 
 
 
 
89d02ed
78c11dd
d3ef534
 
 
 
 
89d02ed
d3ef534
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
505eb00
d3ef534
 
 
 
 
 
 
 
89d02ed
78c11dd
89d02ed
78c11dd
89d02ed
d3ef534
 
89d02ed
 
 
 
 
d3ef534
89d02ed
 
2b38d1f
 
89d02ed
 
 
d3ef534
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import gradio as gr
from transformers import AutoTokenizer, LlamaForCausalLM
import torch

# Use the UrbanGPT model (Llama-based checkpoint hosted on the Hugging Face Hub).
model_name = "bjdwh/UrbanGPT"

# Load the tokenizer and model weights at import time (downloads on first run).
# NOTE(review): trust_remote_code=True executes code shipped with the model
# repository — acceptable only if "bjdwh/UrbanGPT" is a trusted source.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = LlamaForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # half precision to roughly halve memory use
    low_cpu_mem_usage=True,
    trust_remote_code=True
)

def generate_response(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
):
    """Generate one chat reply from the UrbanGPT model.

    Called by gr.ChatInterface; the three trailing parameters come from the
    additional-input sliders, in order.

    Args:
        message: The latest user message.
        history: Prior (user, assistant) turn pairs supplied by Gradio.
        max_tokens: Maximum number of NEW tokens to generate for the reply.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        str: The assistant's reply text (prompt not included).
    """
    # Replay the conversation as "User:/Assistant:" turns, then append the
    # new message.
    input_text = message
    if history:
        input_text = "\n".join(
            f"User: {user_turn}\nAssistant: {assistant_turn}"
            for user_turn, assistant_turn in history
        ) + f"\nUser: {message}"

    # Tokenize and move tensors to the model's device (bug fix: the original
    # left inputs on CPU, which fails whenever the model sits on a GPU).
    inputs = tokenizer(input_text, return_tensors="pt", padding=True).to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            # Pass the mask we asked the tokenizer to build, so padding is
            # handled explicitly instead of inferred.
            attention_mask=inputs["attention_mask"],
            # Bug fix: max_length counts the prompt too, so a long history
            # could leave no room for a reply; max_new_tokens bounds only
            # the generated continuation.
            max_new_tokens=max_tokens,
            # Bug fix: generate() defaults to greedy decoding, silently
            # ignoring temperature/top_p; sampling must be enabled for the
            # sliders to have any effect.
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Bug fix: decode only the newly generated tokens — decoding outputs[0]
    # wholesale echoed the full prompt back to the user when history was empty.
    new_tokens = outputs[0][prompt_len:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    yield response.strip()

# Build the Gradio chat UI; the extra sliders are forwarded to
# generate_response as max_tokens, temperature and top_p, in that order.
demo = gr.ChatInterface(
    generate_response,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="生成最大长度"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="温度"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (核采样)",
        ),
    ],
    title="UrbanGPT 聊天助手",
    description="这是一个基于 UrbanGPT 的中文城市规划对话模型",
)

# Start the web server only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()