cutechicken committed on
Commit
cf528b4
•
1 Parent(s): 0cdcb4f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -15
app.py CHANGED
@@ -70,28 +70,35 @@ class ModelManager:
70
  prompt += f"Assistant: {content}\n"
71
  prompt += "Assistant: "
72
 
73
- # 토크나이징
74
- input_ids = self.tokenizer(
75
  prompt,
76
  return_tensors="pt",
77
  padding=True,
78
  truncation=True,
79
  max_length=4096
80
- ).input_ids
81
-
82
- # 생성
83
- outputs = self.model.generate(
84
- input_ids,
85
- max_new_tokens=max_tokens,
86
- do_sample=True,
87
- temperature=temperature,
88
- top_p=top_p,
89
- pad_token_id=self.tokenizer.pad_token_id,
90
- eos_token_id=self.tokenizer.eos_token_id,
91
- num_return_sequences=1
92
  )
 
 
 
 
93
 
94
- # 디코딩
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  generated_text = self.tokenizer.decode(
96
  outputs[0][input_ids.shape[1]:],
97
  skip_special_tokens=True
 
70
  prompt += f"Assistant: {content}\n"
71
  prompt += "Assistant: "
72
 
73
+ # 토크나이징 및 device 설정
74
+ inputs = self.tokenizer(
75
  prompt,
76
  return_tensors="pt",
77
  padding=True,
78
  truncation=True,
79
  max_length=4096
 
 
 
 
 
 
 
 
 
 
 
 
80
  )
81
+
82
+ # ๋ชจ๋“  ํ…์„œ๋ฅผ GPU๋กœ ์ด๋™
83
+ input_ids = inputs.input_ids.to(self.model.device)
84
+ attention_mask = inputs.attention_mask.to(self.model.device)
85
 
86
+ # 생성
87
+ with torch.no_grad():
88
+ outputs = self.model.generate(
89
+ input_ids=input_ids,
90
+ attention_mask=attention_mask,
91
+ max_new_tokens=max_tokens,
92
+ do_sample=True,
93
+ temperature=temperature,
94
+ top_p=top_p,
95
+ pad_token_id=self.tokenizer.pad_token_id,
96
+ eos_token_id=self.tokenizer.eos_token_id,
97
+ num_return_sequences=1
98
+ )
99
+
100
+ # ๋””์ฝ”๋”ฉ ์ „์— CPU๋กœ ์ด๋™
101
+ outputs = outputs.cpu()
102
  generated_text = self.tokenizer.decode(
103
  outputs[0][input_ids.shape[1]:],
104
  skip_special_tokens=True