TrabbyPatty committed on
Commit
c5552a7
·
verified ·
1 Parent(s): 588478f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -10
app.py CHANGED
@@ -3,26 +3,28 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
3
  import torch
4
  import os
5
 
 
6
  model_id = "TrabbyPatty/mistral-7b-instruct-finetuned-flashcards"
7
- hf_token = os.getenv("alluse") # retrieve from Space secrets
8
 
9
- tokenizer = AutoTokenizer.from_pretrained(model_id)
 
10
  model = AutoModelForCausalLM.from_pretrained(
11
  model_id,
12
- device_map="cpu", # force CPU
13
- torch_dtype="float32" # safest for CPU
 
14
  )
15
 
16
-
17
  pipe = pipeline(
18
  "text-generation",
19
  model=model,
20
  tokenizer=tokenizer,
21
- torch_dtype=torch.float16,
22
- device_map="auto"
23
  )
24
 
25
-
26
  # === SYSTEM MESSAGE ===
27
  SYSTEM_MESSAGE = """<<SYS>>
28
  You are a strict flashcard generator.
@@ -31,8 +33,8 @@ You are a strict flashcard generator.
31
  - Always follow the requested format exactly.
32
  <</SYS>>"""
33
 
 
34
  def generate(user_input, max_new_tokens=800, temperature=0.5):
35
- # Wrap input with system instruction + prompt template
36
  prompt = (
37
  f"<s>[INST] {SYSTEM_MESSAGE}\n\n"
38
  f"Create a variety of study aids with 10 items each, strictly using only the information provided.\n\n"
@@ -47,7 +49,7 @@ def generate(user_input, max_new_tokens=800, temperature=0.5):
47
  )
48
  return output[0]["generated_text"]
49
 
50
- # Gradio UI
51
  demo = gr.Interface(
52
  fn=generate,
53
  inputs=[
 
3
  import torch
4
  import os
5
 
6
+ # === Model ID and Token ===
7
  model_id = "TrabbyPatty/mistral-7b-instruct-finetuned-flashcards"
8
+ hf_token = os.getenv("alluse") # Hugging Face token from Space secrets
9
 
10
+ # === Load tokenizer & model with authentication ===
11
+ tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
12
  model = AutoModelForCausalLM.from_pretrained(
13
  model_id,
14
+ device_map="cpu", # force CPU
15
+ torch_dtype=torch.float32, # safest for CPU
16
+ token=hf_token
17
  )
18
 
19
+ # === Create pipeline ===
20
  pipe = pipeline(
21
  "text-generation",
22
  model=model,
23
  tokenizer=tokenizer,
24
+ torch_dtype=torch.float32,
25
+ device_map="cpu"
26
  )
27
 
 
28
  # === SYSTEM MESSAGE ===
29
  SYSTEM_MESSAGE = """<<SYS>>
30
  You are a strict flashcard generator.
 
33
  - Always follow the requested format exactly.
34
  <</SYS>>"""
35
 
36
+ # === Generation function ===
37
  def generate(user_input, max_new_tokens=800, temperature=0.5):
 
38
  prompt = (
39
  f"<s>[INST] {SYSTEM_MESSAGE}\n\n"
40
  f"Create a variety of study aids with 10 items each, strictly using only the information provided.\n\n"
 
49
  )
50
  return output[0]["generated_text"]
51
 
52
+ # === Gradio UI ===
53
  demo = gr.Interface(
54
  fn=generate,
55
  inputs=[