tolulope committed on
Commit
0906c57
·
verified ·
1 Parent(s): 55f45b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -2
app.py CHANGED
@@ -3,9 +3,51 @@ import transformers
3
  import torch
4
  from peft import PeftModel
5
  import os
 
 
 
 
6
 
 
 
 
 
7
  HF_TOKEN = os.environ.get("HF_TOKEN")
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  model_id = "JerniganLab/interviews-and-qa"
10
  base_model = "meta-llama/Meta-Llama-3-8B-Instruct"
11
 
@@ -22,6 +64,17 @@ pipeline = transformers.pipeline(
22
 
23
  pipeline.model = PeftModel.from_pretrained(llama_model, model_id)
24
 
 
 
 
 
 
 
 
 
 
 
 
25
  # def chat_function(message, history, system_prompt, max_new_tokens, temperature):
26
  # messages = [{"role":"system","content":system_prompt},
27
  # {"role":"user", "content":message}]
@@ -59,6 +112,7 @@ def chat_function(message, history, max_new_tokens, temperature):
59
  do_sample = True,
60
  temperature = temperature + 0.1,
61
  top_p = 0.9,)
 
62
  return outputs[0]["generated_text"][len(prompt):]
63
 
64
  """
@@ -82,10 +136,12 @@ demo = gr.ChatInterface(
82
  additional_inputs=[
83
  gr.Slider(100,4000, label="Max New Tokens"),
84
  gr.Slider(0,1, label="Temperature")
85
- ]
 
 
86
  )
87
 
88
 
89
 
90
  if __name__ == "__main__":
91
- demo.launch()
 
3
  import torch
4
  from peft import PeftModel
5
  import os
6
+ import csv
7
+ import huggingface_hub
8
+ from huggingface_hub import Repository, hf_hub_download, upload_file
9
+ from datetime import datetime
10
 
11
+ DATASET_REPO_URL = "https://huggingface.co/datasets/JerniganLab/chat-data"
12
+ DATASET_REPO_ID = "JerniganLab/chat-data"
13
+ DATA_FILENAME = "data.csv"
14
+ DATA_FILE = os.path.join("data", DATA_FILENAME)
15
  HF_TOKEN = os.environ.get("HF_TOKEN")
16
 
17
+
18
+ HF_TOKEN = os.environ.get("HF_TOKEN")  # NOTE(review): duplicate — HF_TOKEN is already assigned identically a few lines above; one of the two can be deleted.
19
+
20
# overriding/appending to the gradio template
# Injects a one-shot script that auto-clicks the submit button on page load.
# The window.hasBeenRun flag only guards the BROWSER side; see the note on the
# open() call below about the server-side file mutation.
SCRIPT = """
<script>
if (!window.hasBeenRun) {
window.hasBeenRun = true;
console.log("should only happen once");
document.querySelector("button.submit").click();
}
</script>
"""
# NOTE(review): gr.routes.STATIC_TEMPLATE_LIB is an internal/undocumented
# Gradio attribute — confirm it still exists in the installed Gradio version.
# Opening in append mode ("a") permanently mutates the shared template file on
# disk, so the script persists across restarts and would be duplicated if this
# module-level code runs more than once per environment.
with open(os.path.join(gr.routes.STATIC_TEMPLATE_LIB, "frontend", "index.html"), "a") as f:
    f.write(SCRIPT)
32
+
33
# Seed the local CSV with the latest copy from the dataset repo.
# Best effort: on a brand-new dataset the file may not exist remotely yet.
try:
    hf_hub_download(
        repo_id=DATASET_REPO_ID,
        filename=DATA_FILENAME,
        # BUG FIX: was `cache_dir=DATA_DIRNAME` — DATA_DIRNAME is never defined
        # anywhere in this file, so this raised NameError on every run and the
        # bare `except:` below silently swallowed it. Use the same "data"
        # directory that DATA_FILE is built from.
        cache_dir="data",
        repo_type="dataset",
        # NOTE(review): force_filename was removed in newer huggingface_hub
        # releases — confirm the pinned version still supports it.
        force_filename=DATA_FILENAME,
    )
except Exception as err:
    # Narrowed from a bare `except:` (which also masked the NameError above)
    # and made the message report what actually went wrong.
    print(f"could not fetch {DATA_FILENAME}: {err}")
43
+
44
# Clone the dataset repo into ./data so appended chat rows can be committed
# and pushed back (see store_message below in this file).
# NOTE(review): huggingface_hub.Repository is deprecated in recent versions of
# huggingface_hub — confirm against the pinned version; HfApi/upload_file is
# the recommended replacement.
repo = Repository(
    local_dir="data", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
)
47
+
48
+
49
+
50
+
51
  model_id = "JerniganLab/interviews-and-qa"
52
  base_model = "meta-llama/Meta-Llama-3-8B-Instruct"
53
 
 
64
 
65
  pipeline.model = PeftModel.from_pretrained(llama_model, model_id)
66
 
67
def store_message(message: str, system_prompt: str, response: str) -> None:
    """Append one chat exchange to the CSV log and push it to the dataset repo.

    Skips logging entirely when either the user message or the model response
    is empty/falsy.

    NOTE(review): the caller visible in this diff invokes
    store_message(message, system_prompt, ...) from inside chat_function, whose
    signature is (message, history, max_new_tokens, temperature) — i.e. no
    system_prompt in scope. Verify the caller; it looks like a NameError.
    """
    if response and message:
        fieldnames = ["message", "system_prompt", "response", "time"]
        # Write a header row the first time the file is created (or when it is
        # empty) so the CSV stays self-describing — the original never wrote
        # one, producing a headerless file that DictReader cannot parse.
        needs_header = (
            not os.path.isfile(DATA_FILE) or os.path.getsize(DATA_FILE) == 0
        )
        # newline="" is required by the csv module to avoid blank rows on
        # Windows and to let the writer control line endings.
        with open(DATA_FILE, "a", newline="") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            if needs_header:
                writer.writeheader()
            writer.writerow(
                {
                    "message": message,
                    "system_prompt": system_prompt,
                    "response": response,
                    "time": str(datetime.now()),
                }
            )
        # Push the updated CSV back to the HF dataset repo (the returned commit
        # URL was previously bound to an unused local; dropped).
        repo.push_to_hub()
    # return generate_html()
76
+
77
+
78
  # def chat_function(message, history, system_prompt, max_new_tokens, temperature):
79
  # messages = [{"role":"system","content":system_prompt},
80
  # {"role":"user", "content":message}]
 
112
  do_sample = True,
113
  temperature = temperature + 0.1,
114
  top_p = 0.9,)
115
+ store_message(message, system_prompt, outputs[0]["generated_text"][len(prompt):])
116
  return outputs[0]["generated_text"][len(prompt):]
117
 
118
  """
 
136
  additional_inputs=[
137
  gr.Slider(100,4000, label="Max New Tokens"),
138
  gr.Slider(0,1, label="Temperature")
139
+ ],
140
+ type="messages",
141
+ save_history=True,
142
  )
143
 
144
 
145
 
146
  if __name__ == "__main__":
147
+ demo.launch()