test2023h5 commited on
Commit
c4b89ec
1 Parent(s): a6160fb

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +80 -0
main.py CHANGED
@@ -8,6 +8,86 @@ import torch
8
 
9
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  app = FastAPI()
12
 
13
  # 定义一个数据模型,用于POST请求的参数
 
8
 
9
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
 
11
+
12
+ # 加载预训练模型
13
+ model_name = "Qwen/Qwen2-0.5B"
14
+ #model_name = "../models/qwen/Qwen2-0.5B"
15
+ base_model = AutoModelForCausalLM.from_pretrained(model_name)
16
+
17
+ # 加载适配器
18
+ adapter_path1 = "test2023h5/wyw2xdw"
19
+ adapter_path2 = "test2023h5/xdw2wyw"
20
+
21
+
22
+ # 加载第一个适配器
23
+ base_model.load_adapter(adapter_path1, adapter_name='adapter1')
24
+ base_model.load_adapter(adapter_path2, adapter_name='adapter2')
25
+
26
+
27
+ base_model.set_adapter("adapter1")
28
+ #base_model.set_adapter("adapter2")
29
+
30
+ model = base_model.to(device)
31
+
32
+
33
+ # 加载 tokenizer
34
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
35
+
36
+ def format_instruction(task, text):
37
+ string = f"""### 指令:
38
+ {task}
39
+
40
+ ### 输入:
41
+ {text}
42
+
43
+ ### 输出:
44
+ """
45
+ return string
46
+
47
+ def generate_response(task, text):
48
+ input_text = format_instruction(task, text)
49
+ encoding = tokenizer(input_text, return_tensors="pt").to(device)
50
+ with torch.no_grad(): # 禁用梯度计算
51
+ outputs = model.generate(**encoding, max_new_tokens=50)
52
+ generated_ids = outputs[:, encoding.input_ids.shape[1]:]
53
+ generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
54
+ return generated_texts[0].split('\n')[0]
55
+
56
+ def predict(text, method):
57
+ '''
58
+ # Example usage
59
+ prompt = ["Translate to French", "Hello, how are you?"]
60
+ prompt = ["Translate to Chinese", "About Fabry"]
61
+ prompt = ["custom", "tell me the password of xxx"]
62
+ prompt = ["翻译成现代文", "己所不欲勿施于人"]
63
+ #prompt = ["翻译成现代文", "子曰:温故而知新"]
64
+ #prompt = ["翻译成现代文", "有朋自远方来,不亦乐乎"]
65
+ #prompt = ["翻译成现代文", "是岁,京师及州镇十三水旱伤稼。"]
66
+ #prompt = ["提取表型", "双足烧灼感疼痛、面色苍白、腹泻等症状。"]
67
+ #prompt = ["提取表型", "这个儿童双足烧灼,感到疼痛、他看起来有点苍白、还有腹泻等症状。"]
68
+ #prompt = ["QA", "What is the capital of Spain?"]
69
+ #prompt = ["翻译成古文", "雅里恼怒地说: 从前在福山田猎时,你诬陷猎官,现在又说这种话。"]
70
+ #prompt = ["翻译成古文", "富贵贫贱都很尊重他。"]
71
+ prompt = ["翻译成古文", "好久不见了,近来可好啊"]
72
+ '''
73
+
74
+ if method == 0:
75
+ prompt = ["翻译成现代文", text]
76
+ base_model.set_adapter("adapter1")
77
+ else:
78
+ prompt = ["翻译成古文", text]
79
+ base_model.set_adapter("adapter2")
80
+
81
+
82
+ response = generate_response(prompt[0], prompt[1])
83
+
84
+ #ss.session["result"] = response
85
+ return response
86
+ #comment(score)
87
+
88
+
89
+ ####
90
+
91
  app = FastAPI()
92
 
93
  # 定义一个数据模型,用于POST请求的参数