E-Hospital committed
Commit 8fc450b
1 parent: 9110acb

Update main.py

Files changed (1):
  1. main.py +9 -9
main.py CHANGED
@@ -23,7 +23,7 @@ tokenizer = AutoTokenizer.from_pretrained("Open-Orca/OpenOrca-Platypus2-13B", tr
 def ask_bot(question):
     input_ids = tokenizer.encode(question, return_tensors="pt").to(device)
     with torch.no_grad():
-        output = model.generate(input_ids, max_length=500, num_return_sequences=1, do_sample=True, top_k=50)
+        output = model.generate(input_ids, max_length=100, num_return_sequences=1, do_sample=True, top_k=50)
     generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
     response = generated_text.split("->:")[-1]
     return response
@@ -65,7 +65,7 @@ class CustomLLM(LLM):
 
         input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
         with torch.no_grad():
-            output = model.generate(input_ids, max_length=500, num_return_sequences=1, do_sample=True, top_k=50)
+            output = model.generate(input_ids, max_length=100, num_return_sequences=1, do_sample=True, top_k=50)
         generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
         response = generated_text.split("->:")[-1]
         return response
@@ -156,13 +156,13 @@ def chatbot(patient_id, user_data: dict=None):
         human_input = prompt + user_input + " ->:"
         human_text = user_input.replace("'", "")
        response = llm._call(human_input)
-        response = response.replace("'", "")
-        memory.save_context({"input": user_input}, {"output": response})
-        summary = memory.load_memory_variables({})
-        ai_text = response.replace("'", "")
-        memory.save_context({"input": user_input}, {"output": ai_text})
-        summary = memory.load_memory_variables({})
-        db.insert(("patient_id", "patient_text", "ai_text", "timestamp", "summarized_text"), (patient_id, human_text, ai_text, str(datetime.now()), summary['history'].replace("'", "")))
+        # response = response.replace("'", "")
+        # memory.save_context({"input": user_input}, {"output": response})
+        # summary = memory.load_memory_variables({})
+        # ai_text = response.replace("'", "")
+        # memory.save_context({"input": user_input}, {"output": ai_text})
+        # summary = memory.load_memory_variables({})
+        # db.insert(("patient_id", "patient_text", "ai_text", "timestamp", "summarized_text"), (patient_id, human_text, ai_text, str(datetime.now()), summary['history'].replace("'", "")))
         db.close_db()
         return {"response": response}
     finally:
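
The first two hunks lower max_length from 500 to 100 in both generation paths. In the transformers generate API, max_length bounds the total sequence, prompt tokens included, so a long prompt now leaves little room for the reply; max_new_tokens would bound only the generated part. Below is a minimal standalone sketch of the updated ask_bot path; the model/tokenizer loading and device setup are assumptions inferred from the hunk headers, not shown in this diff.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed setup, inferred from the hunk headers; the actual main.py may differ.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("Open-Orca/OpenOrca-Platypus2-13B")
model = AutoModelForCausalLM.from_pretrained("Open-Orca/OpenOrca-Platypus2-13B").to(device)

def ask_bot(question):
    input_ids = tokenizer.encode(question, return_tensors="pt").to(device)
    with torch.no_grad():
        # max_length=100 caps prompt + generated tokens combined;
        # max_new_tokens=100 would cap only the newly generated tokens.
        output = model.generate(input_ids, max_length=100,
                                num_return_sequences=1, do_sample=True, top_k=50)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    # The prompt template ends with "->:"; the reply is everything after it.
    return generated_text.split("->:")[-1]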
 
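The third hunk disables the conversation-memory bookkeeping and the database insert in chatbot, so the reply is still returned but no longer persisted. If it is re-enabled later, the block can also be deduplicated: as written it saved the same exchange and recomputed the summary twice. A sketch of a collapsed version follows; persist_exchange is a hypothetical helper, and the memory object's class and the db wrapper's schema are not visible in this diff.

from datetime import datetime

def persist_exchange(db, memory, patient_id, user_input, response):
    # Hypothetical helper mirroring the disabled block, with the duplicate
    # save_context/load_memory_variables calls removed.
    human_text = user_input.replace("'", "")  # strip quotes, as the original did
    ai_text = response.replace("'", "")
    memory.save_context({"input": user_input}, {"output": ai_text})
    summary = memory.load_memory_variables({})
    db.insert(
        ("patient_id", "patient_text", "ai_text", "timestamp", "summarized_text"),
        (patient_id, human_text, ai_text, str(datetime.now()),
         summary["history"].replace("'", "")),
    )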