kjozsa committed
Commit 8a8ca58
1 Parent(s): 4577a87

transformers chat

Files changed (3):
  1. app.py +4 -20
  2. ollamachat.py +24 -0
  3. transformerschat.py +37 -0
app.py CHANGED
@@ -1,28 +1,11 @@
 import re
 
-import ollama
 import streamlit as st
 from loguru import logger
+# from ollamachat import ask, models
+from transformerschat import ask, models
 
-available_models = sorted([x['model'] for x in ollama.list()['models']], key=lambda x: (not x.startswith("openhermes"), x))
-
-
-def ask(model, system_prompt, pre_prompt, question):
-    messages = [
-        {
-            'role': 'system',
-            'content': f"{system_prompt} {pre_prompt}",
-        },
-        {
-            'role': 'user',
-            'content': f"{question}",
-        },
-    ]
-    logger.debug(f"<< {model} << {question}")
-    response = ollama.chat(model=model, messages=messages)
-    answer = response['message']['content']
-    logger.debug(f">> {model} >> {answer}")
-    return answer
+available_models = models()
 
 
 class Actor:
@@ -70,6 +53,7 @@ def main():
     actor = target(question)
 
 
+# noinspection PyTypeChecker
 def target(question) -> Actor:
     try:
         role = re.split(r'\s|,|:', question.strip())[0].strip()
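The point of the change above is a pluggable chat backend: ollamachat and transformerschat export the same models() / ask() pair, so app.py switches engines by editing one import line. A minimal sketch of the implied contract and a call site (the prompt strings are illustrative, not from the repo):

# The implied backend contract (not declared anywhere in the repo):
#   models() -> list[str]                                   # tags for the model picker
#   ask(model, system_prompt, pre_prompt, question) -> str  # one chat turn
# so app.py only needs:
from transformerschat import ask, models  # or: from ollamachat import ask, models

available_models = models()
answer = ask(
    model=available_models[0],
    system_prompt="You are a helpful assistant.",  # illustrative
    pre_prompt="Stay in character.",               # illustrative
    question="philosopher: what is the meaning of life?",
)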
ollamachat.py ADDED
@@ -0,0 +1,24 @@
+from loguru import logger
+import ollama
+
+
+def models():
+    return sorted([x['model'] for x in ollama.list()['models']], key=lambda x: (not x.startswith("openhermes"), x))
+
+
+def ask(model, system_prompt, pre_prompt, question):
+    messages = [
+        {
+            'role': 'system',
+            'content': f"{system_prompt} {pre_prompt}",
+        },
+        {
+            'role': 'user',
+            'content': f"{question}",
+        },
+    ]
+    logger.debug(f"<< {model} << {question}")
+    response = ollama.chat(model=model, messages=messages)
+    answer = response['message']['content']
+    logger.debug(f">> {model} >> {answer}")
+    return answer
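A hedged usage sketch for this backend (assumes a local Ollama daemon is running and at least one model has been pulled, e.g. with `ollama pull openhermes`; the prompts are illustrative):

from ollamachat import ask, models

tags = models()  # e.g. ['openhermes:latest', ...]; openhermes tags sort first
reply = ask(
    model=tags[0],
    system_prompt="You are a helpful assistant.",
    pre_prompt="Answer in one sentence.",
    question="What does Streamlit do?",
)
print(reply)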
transformerschat.py ADDED
@@ -0,0 +1,37 @@
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from loguru import logger
+import spaces
+
+
+def models():
+    return ["openhermes"]
+
+
+def load():
+    torch.set_default_device("cuda")
+    model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype="auto", trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
+    return (model, tokenizer)
+
+
+model, tokenizer = load()
+
+
+def ask(model, system_prompt, pre_prompt, question):
+    messages = [
+        {
+            'role': 'system',
+            'content': f"{system_prompt} {pre_prompt}",
+        },
+        {
+            'role': 'user',
+            'content': f"{question}",
+        },
+    ]
+    logger.debug(f"<< {model} << {question}")
+    inputs = tokenizer(question, return_tensors="pt", return_attention_mask=False)
+    outputs = model.generate(**inputs, max_length=200)
+    answer = tokenizer.batch_decode(outputs)[0]
+    logger.debug(f">> {model} >> {answer}")
+    return answer
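Two things in this ask() look unintended as committed: the model parameter (a string such as "openhermes") shadows the module-level phi-2 model, so model.generate(...) is called on a string and raises AttributeError; and messages is built but never used, so the system prompt never reaches the model (import spaces is also unused). A hedged sketch of a corrected version, assuming the plain Instruct/Output prompt layout from the phi-2 model card and a renamed model_name parameter (both are our assumptions, not from the commit):

def ask_fixed(model_name, system_prompt, pre_prompt, question):
    # phi-2 has no chat template, so fold the system prompt into a plain
    # "Instruct: ... Output:" prompt (layout per the phi-2 model card).
    prompt = f"Instruct: {system_prompt} {pre_prompt}\n{question}\nOutput:"
    logger.debug(f"<< {model_name} << {question}")
    inputs = tokenizer(prompt, return_tensors="pt", return_attention_mask=False)
    # 'model' here unambiguously refers to the module-level phi-2 instance
    outputs = model.generate(**inputs, max_length=200)
    answer = tokenizer.batch_decode(outputs)[0]
    logger.debug(f">> {model_name} >> {answer}")
    return answer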