HubertMJ commited on
Commit
4644de9
1 Parent(s): fbb2963

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -0
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple, Any
4
+
5
+
6
+ class SeparatorStyle(Enum):
7
+ """Different separator style."""
8
+ SINGLE = auto()
9
+ TWO = auto()
10
+
11
+
12
+ @dataclasses.dataclass
13
+ class Conversation:
14
+ """A class that keeps all conversation history."""
15
+ system: str
16
+ roles: List[str]
17
+ messages: List[List[str]]
18
+ offset: int
19
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
20
+ sep: str = "###"
21
+ sep2: str = None
22
+
23
+ # Used for gradio server
24
+ skip_next: bool = False
25
+ conv_id: Any = None
26
+
27
+ def get_prompt(self):
28
+ if self.sep_style == SeparatorStyle.SINGLE:
29
+ ret = self.system
30
+ for role, message in self.messages:
31
+ if message:
32
+ ret += self.sep + " " + role + ": " + message
33
+ else:
34
+ ret += self.sep + " " + role + ":"
35
+ return ret
36
+ elif self.sep_style == SeparatorStyle.TWO:
37
+ seps = [self.sep, self.sep2]
38
+ ret = self.system + seps[0]
39
+ for i, (role, message) in enumerate(self.messages):
40
+ if message:
41
+ ret += role + ": " + message + seps[i % 2]
42
+ else:
43
+ ret += role + ":"
44
+ return ret
45
+ else:
46
+ raise ValueError(f"Invalid style: {self.sep_style}")
47
+
48
+ def append_message(self, role, message):
49
+ self.messages.append([role, message])
50
+
51
+ def to_gradio_chatbot(self):
52
+ ret = []
53
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
54
+ if i % 2 == 0:
55
+ ret.append([msg, None])
56
+ else:
57
+ ret[-1][-1] = msg
58
+ return ret
59
+
60
+ def copy(self):
61
+ return Conversation(
62
+ system=self.system,
63
+ roles=self.roles,
64
+ messages=[[x, y] for x, y in self.messages],
65
+ offset=self.offset,
66
+ sep_style=self.sep_style,
67
+ sep=self.sep,
68
+ sep2=self.sep2,
69
+ conv_id=self.conv_id)
70
+
71
+ def dict(self):
72
+ return {
73
+ "system": self.system,
74
+ "roles": self.roles,
75
+ "messages": self.messages,
76
+ "offset": self.offset,
77
+ "sep": self.sep,
78
+ "sep2": self.sep2,
79
+ "conv_id": self.conv_id,
80
+ }
81
+
82
+
83
+
84
+ conv = Conversation(
85
+ system="A chat between a curious user and an artificial intelligence assistant. "
86
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
87
+ roles=("USER", "ASSISTANT"),
88
+ messages=[],
89
+ offset=0,
90
+ sep_style=SeparatorStyle.TWO,
91
+ sep=" ",
92
+ sep2="</s>",
93
+ )
94
+
95
+ conv.append_message(conv.roles[0], "Why would Microsoft take this down?")
96
+ conv.append_message(conv.roles[1], None)
97
+ prompt = conv.get_prompt()
98
+
99
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
100
+
101
+ result = model.generate(**inputs, max_new_tokens=1000)
102
+ generated_ids = result[0]
103
+ generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
104
+ print(generated_text)