MLCraftsman commited on
Commit
2bc2dcf
·
verified ·
1 Parent(s): 7c2f77f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -42
app.py CHANGED
@@ -12,74 +12,70 @@ st.set_page_config(
12
  )
13
 
14
  # -----------------------------
15
- # Sidebar Settings
 
 
 
 
 
16
  # -----------------------------
17
  st.sidebar.title("⚙️ Settings")
18
 
19
  model_path = st.sidebar.text_input(
20
- "Model Path",
21
- value="gpt2" # Change to "./results" if using fine-tuned model
22
  )
23
 
24
- max_length = st.sidebar.slider("Max Length", 50, 500, 150)
25
- temperature = st.sidebar.slider("Temperature (Creativity)", 0.5, 1.5, 0.8)
26
  top_k = st.sidebar.slider("Top-K", 10, 100, 50)
27
  top_p = st.sidebar.slider("Top-P", 0.5, 1.0, 0.95)
28
 
29
- device = "cuda" if torch.cuda.is_available() else "cpu"
30
  st.sidebar.write(f"Device: **{device.upper()}**")
31
 
32
  # -----------------------------
33
  # Title
34
  # -----------------------------
35
  st.title("🤖 Professional AI Text Generator")
36
- st.markdown(
37
- "Generate creative and grammatically correct text using a GPT-based model."
38
- )
39
 
40
  # -----------------------------
41
- # Load Model (Cached)
42
  # -----------------------------
43
  @st.cache_resource
44
- def load_model(path):
45
- tokenizer = AutoTokenizer.from_pretrained(path)
46
  tokenizer.pad_token = tokenizer.eos_token
47
 
48
- model = AutoModelForCausalLM.from_pretrained(path)
 
 
 
 
49
  model.to(device)
50
  model.eval()
51
 
52
  return tokenizer, model
53
 
 
54
  # Load model safely
55
  try:
56
  tokenizer, model = load_model(model_path)
57
  except Exception as e:
58
- st.error(f"Error loading model: {e}")
59
  st.stop()
60
 
61
  # -----------------------------
62
  # Input Area
63
  # -----------------------------
64
- col1, col2 = st.columns([2, 1])
65
-
66
- with col1:
67
- prompt = st.text_area(
68
- "Enter your prompt:",
69
- height=200,
70
- placeholder="Example: Alice was walking through the forest when..."
71
- )
72
-
73
- with col2:
74
- st.info(
75
- "Tips:\n"
76
- "- Higher temperature = more creative\n"
77
- "- Lower temperature = more accurate\n"
78
- "- Use your fine-tuned model for best results"
79
- )
80
 
81
  # -----------------------------
82
- # Generate Text
83
  # -----------------------------
84
  if st.button("✨ Generate Text", use_container_width=True):
85
 
@@ -90,15 +86,16 @@ if st.button("✨ Generate Text", use_container_width=True):
90
 
91
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
92
 
93
- output = model.generate(
94
- **inputs,
95
- max_length=max_length,
96
- temperature=temperature,
97
- top_k=top_k,
98
- top_p=top_p,
99
- do_sample=True,
100
- pad_token_id=tokenizer.eos_token_id
101
- )
 
102
 
103
  generated_text = tokenizer.decode(
104
  output[0],
@@ -108,9 +105,8 @@ if st.button("✨ Generate Text", use_container_width=True):
108
  st.subheader("Generated Output")
109
  st.write(generated_text)
110
 
111
- # Download Button
112
  st.download_button(
113
- label="📥 Download Text",
114
  data=generated_text,
115
  file_name="generated_text.txt",
116
  mime="text/plain"
 
12
  )
13
 
14
  # -----------------------------
15
+ # Device Setup (HF Spaces safe)
16
+ # -----------------------------
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+
19
+ # -----------------------------
20
+ # Sidebar
21
  # -----------------------------
22
  st.sidebar.title("⚙️ Settings")
23
 
24
  model_path = st.sidebar.text_input(
25
+ "Model Name / Path",
26
+ value="gpt2"
27
  )
28
 
29
+ max_new_tokens = st.sidebar.slider("Max New Tokens", 20, 300, 100)
30
+ temperature = st.sidebar.slider("Temperature", 0.5, 1.5, 0.8)
31
  top_k = st.sidebar.slider("Top-K", 10, 100, 50)
32
  top_p = st.sidebar.slider("Top-P", 0.5, 1.0, 0.95)
33
 
 
34
  st.sidebar.write(f"Device: **{device.upper()}**")
35
 
36
  # -----------------------------
37
  # Title
38
  # -----------------------------
39
  st.title("🤖 Professional AI Text Generator")
40
+ st.markdown("Generate text using Hugging Face models.")
 
 
41
 
42
  # -----------------------------
43
+ # Load Model (cached)
44
  # -----------------------------
45
  @st.cache_resource
46
+ def load_model(model_name):
47
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
48
  tokenizer.pad_token = tokenizer.eos_token
49
 
50
+ model = AutoModelForCausalLM.from_pretrained(
51
+ model_name,
52
+ torch_dtype=torch.float32 # safer for CPU Spaces
53
+ )
54
+
55
  model.to(device)
56
  model.eval()
57
 
58
  return tokenizer, model
59
 
60
+
61
  # Load model safely
62
  try:
63
  tokenizer, model = load_model(model_path)
64
  except Exception as e:
65
+ st.error(f"Model loading failed: {e}")
66
  st.stop()
67
 
68
  # -----------------------------
69
  # Input Area
70
  # -----------------------------
71
+ prompt = st.text_area(
72
+ "Enter your prompt:",
73
+ height=200,
74
+ placeholder="Example: Once upon a time..."
75
+ )
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  # -----------------------------
78
+ # Generate Button
79
  # -----------------------------
80
  if st.button("✨ Generate Text", use_container_width=True):
81
 
 
86
 
87
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
88
 
89
+ with torch.no_grad():
90
+ output = model.generate(
91
+ **inputs,
92
+ max_new_tokens=max_new_tokens,
93
+ temperature=temperature,
94
+ top_k=top_k,
95
+ top_p=top_p,
96
+ do_sample=True,
97
+ pad_token_id=tokenizer.eos_token_id
98
+ )
99
 
100
  generated_text = tokenizer.decode(
101
  output[0],
 
105
  st.subheader("Generated Output")
106
  st.write(generated_text)
107
 
 
108
  st.download_button(
109
+ label="📥 Download",
110
  data=generated_text,
111
  file_name="generated_text.txt",
112
  mime="text/plain"