Spaces:
Runtime error
Runtime error
michal
commited on
Commit
•
5b01087
1
Parent(s):
5f4adea
the brain splice commit !!
Browse files- app.py +22 -2
- audios/tempfile.mp3 +0 -0
- greg_funcs.py +39 -1
app.py
CHANGED
@@ -42,7 +42,7 @@ from torch import tensor as torch_tensor
|
|
42 |
from datasets import load_dataset
|
43 |
|
44 |
|
45 |
-
from greg_funcs import
|
46 |
|
47 |
"""# import models"""
|
48 |
|
@@ -491,6 +491,10 @@ class ChatWrapper:
|
|
491 |
):
|
492 |
"""Execute the chat functionality."""
|
493 |
self.lock.acquire()
|
|
|
|
|
|
|
|
|
494 |
try:
|
495 |
print("\n==== date/time: " + str(datetime.datetime.now()) + " ====")
|
496 |
print("inp: " + inp)
|
@@ -504,7 +508,10 @@ class ChatWrapper:
|
|
504 |
output = AUTH_ERR_MSG
|
505 |
hidden_text = output
|
506 |
|
507 |
-
|
|
|
|
|
|
|
508 |
if chain:
|
509 |
# Set OpenAI key
|
510 |
import openai
|
@@ -515,12 +522,25 @@ class ChatWrapper:
|
|
515 |
chain, inp, capture_hidden_text=trace_chain)
|
516 |
else:
|
517 |
output, hidden_text = inp, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
518 |
|
519 |
output = transform_text(output, express_chain, num_words, formality, anticipation_level, joy_level,
|
520 |
trust_level,
|
521 |
fear_level, surprise_level, sadness_level, disgust_level, anger_level,
|
522 |
lang_level, translate_to, literary_style)
|
523 |
|
|
|
|
|
524 |
text_to_display = output
|
525 |
if trace_chain:
|
526 |
text_to_display = hidden_text + "\n\n" + output
|
|
|
42 |
from datasets import load_dataset
|
43 |
|
44 |
|
45 |
+
from greg_funcs import get_llm_response
|
46 |
|
47 |
"""# import models"""
|
48 |
|
|
|
491 |
):
|
492 |
"""Execute the chat functionality."""
|
493 |
self.lock.acquire()
|
494 |
+
|
495 |
+
|
496 |
+
|
497 |
+
# import ipdb; ipdb.set_trace()
|
498 |
try:
|
499 |
print("\n==== date/time: " + str(datetime.datetime.now()) + " ====")
|
500 |
print("inp: " + inp)
|
|
|
508 |
output = AUTH_ERR_MSG
|
509 |
hidden_text = output
|
510 |
|
511 |
+
|
512 |
+
output = get_llm_response(inp)
|
513 |
+
|
514 |
+
"""
|
515 |
if chain:
|
516 |
# Set OpenAI key
|
517 |
import openai
|
|
|
522 |
chain, inp, capture_hidden_text=trace_chain)
|
523 |
else:
|
524 |
output, hidden_text = inp, None
|
525 |
+
"""
|
526 |
+
print("original output", output)
|
527 |
+
print("using these knobs:",
|
528 |
+
(
|
529 |
+
formality, anticipation_level, joy_level,
|
530 |
+
trust_level,
|
531 |
+
fear_level, surprise_level, sadness_level, disgust_level, anger_level,
|
532 |
+
lang_level, translate_to, literary_style
|
533 |
+
)
|
534 |
+
|
535 |
+
)
|
536 |
|
537 |
output = transform_text(output, express_chain, num_words, formality, anticipation_level, joy_level,
|
538 |
trust_level,
|
539 |
fear_level, surprise_level, sadness_level, disgust_level, anger_level,
|
540 |
lang_level, translate_to, literary_style)
|
541 |
|
542 |
+
print("transformed output", output)
|
543 |
+
|
544 |
text_to_display = output
|
545 |
if trace_chain:
|
546 |
text_to_display = hidden_text + "\n\n" + output
|
audios/tempfile.mp3
CHANGED
Binary files a/audios/tempfile.mp3 and b/audios/tempfile.mp3 differ
|
|
greg_funcs.py
CHANGED
@@ -4,6 +4,13 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
|
4 |
from torch import tensor as torch_tensor
|
5 |
from datasets import load_dataset
|
6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
|
9 |
"""# import models"""
|
@@ -25,7 +32,7 @@ dataset_embed = load_dataset("gfhayworth/hack_policy_embed", split='train')
|
|
25 |
dataset_embed_pd = dataset_embed.to_pandas()
|
26 |
mycorpus_embeddings = torch_tensor(dataset_embed_pd.values)
|
27 |
|
28 |
-
def
|
29 |
question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
|
30 |
question_embedding = question_embedding #.cuda()
|
31 |
hits = util.semantic_search(question_embedding, doc_embedding, top_k=top_k)
|
@@ -45,3 +52,34 @@ def greg_search(query, passages = mypassages, doc_embedding = mycorpus_embedding
|
|
45 |
# for hit in hits[0:3]:
|
46 |
# print("\t{:.3f}\t{}".format(hit['cross-score'], mypassages[hit['corpus_id']].replace("\n", " ")))
|
47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
from torch import tensor as torch_tensor
|
5 |
from datasets import load_dataset
|
6 |
|
7 |
+
from langchain.llms import OpenAI
|
8 |
+
from langchain.docstore.document import Document
|
9 |
+
|
10 |
+
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
|
15 |
|
16 |
"""# import models"""
|
|
|
32 |
dataset_embed_pd = dataset_embed.to_pandas()
|
33 |
mycorpus_embeddings = torch_tensor(dataset_embed_pd.values)
|
34 |
|
35 |
+
def search(query, passages = mypassages, doc_embedding = mycorpus_embeddings, top_k=20, top_n = 1):
|
36 |
question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
|
37 |
question_embedding = question_embedding #.cuda()
|
38 |
hits = util.semantic_search(question_embedding, doc_embedding, top_k=top_k)
|
|
|
52 |
# for hit in hits[0:3]:
|
53 |
# print("\t{:.3f}\t{}".format(hit['cross-score'], mypassages[hit['corpus_id']].replace("\n", " ")))
|
54 |
|
55 |
+
|
56 |
+
|
57 |
+
def get_text_fmt(qry, passages = mypassages, doc_embedding=mycorpus_embeddings):
|
58 |
+
predictions = search(qry, passages = passages, doc_embedding = doc_embedding, top_n=5, )
|
59 |
+
prediction_text = []
|
60 |
+
for hit in predictions:
|
61 |
+
page_content = passages[hit['corpus_id']]
|
62 |
+
metadata = {"source": hit['corpus_id']}
|
63 |
+
result = Document(page_content=page_content, metadata=metadata)
|
64 |
+
prediction_text.append(result)
|
65 |
+
return prediction_text
|
66 |
+
|
67 |
+
|
68 |
+
chain_qa = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="stuff")
|
69 |
+
|
70 |
+
|
71 |
+
def get_llm_response(message):
|
72 |
+
mydocs = get_text_fmt(message)
|
73 |
+
response = chain_qa.run(input_documents=mydocs, question=message)
|
74 |
+
return response
|
75 |
+
|
76 |
+
def chat(message, history):
|
77 |
+
history = history or []
|
78 |
+
message = message.lower()
|
79 |
+
|
80 |
+
response = get_llm_response(message)
|
81 |
+
history.append((message, response))
|
82 |
+
return history, history
|
83 |
+
|
84 |
+
|
85 |
+
|