E-Hospital committed on
Commit
e0f4606
1 Parent(s): e16bb65

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +4 -11
main.py CHANGED
@@ -1,7 +1,9 @@
1
  from transformers import AutoTokenizer, AutoModelForCausalLM
2
  import torch
3
  import os
 
4
 
 
5
 
6
  model = AutoModelForCausalLM.from_pretrained(
7
  "E-Hospital/open-orca-platypus-2-lora-medical",
@@ -11,7 +13,7 @@ model = AutoModelForCausalLM.from_pretrained(
11
  tokenizer = AutoTokenizer.from_pretrained("Open-Orca/OpenOrca-Platypus2-13B", trust_remote_code=True)
12
 
13
  def ask_bot(question):
14
- input_ids = tokenizer.encode(question, return_tensors="pt").to('cuda')
15
  with torch.no_grad():
16
  output = model.generate(input_ids, max_length=500, num_return_sequences=1, do_sample=True, top_k=50)
17
  generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
@@ -53,7 +55,7 @@ class CustomLLM(LLM):
53
  if stop is not None:
54
  raise ValueError("stop kwargs are not permitted.")
55
 
56
- input_ids = tokenizer.encode(prompt, return_tensors="pt").to('cuda')
57
  with torch.no_grad():
58
  output = model.generate(input_ids, max_length=500, num_return_sequences=1, do_sample=True, top_k=50)
59
  generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
@@ -67,15 +69,6 @@ class CustomLLM(LLM):
67
 
68
 
69
 
70
- def ask_bot(question):
71
- input_ids = tokenizer.encode(question, return_tensors="pt").to('cuda')
72
- with torch.no_grad():
73
- output = model.generate(input_ids, max_length=500, num_return_sequences=1, do_sample=True, top_k=50)
74
- generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
75
- response = generated_text.split("->:")[-1]
76
- return response
77
-
78
-
79
  class DbHandler():
80
  def __init__(self):
81
  self.db_con = mysql.connector.connect(
 
1
  from transformers import AutoTokenizer, AutoModelForCausalLM
2
  import torch
3
  import os
4
+ os.environ["CUDA_VISIBLE_DEVICES"]="0"
5
 
6
+ device = torch.device("cuda")
7
 
8
  model = AutoModelForCausalLM.from_pretrained(
9
  "E-Hospital/open-orca-platypus-2-lora-medical",
 
13
  tokenizer = AutoTokenizer.from_pretrained("Open-Orca/OpenOrca-Platypus2-13B", trust_remote_code=True)
14
 
15
  def ask_bot(question):
16
+ input_ids = tokenizer.encode(question, return_tensors="pt").to(device)
17
  with torch.no_grad():
18
  output = model.generate(input_ids, max_length=500, num_return_sequences=1, do_sample=True, top_k=50)
19
  generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
 
55
  if stop is not None:
56
  raise ValueError("stop kwargs are not permitted.")
57
 
58
+ input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
59
  with torch.no_grad():
60
  output = model.generate(input_ids, max_length=500, num_return_sequences=1, do_sample=True, top_k=50)
61
  generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
 
69
 
70
 
71
 
 
 
 
 
 
 
 
 
 
72
  class DbHandler():
73
  def __init__(self):
74
  self.db_con = mysql.connector.connect(