hruday96 commited on
Commit
4230190
1 Parent(s): f439464

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -1,19 +1,19 @@
1
  import streamlit as st
2
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
3
 
4
  # App header
5
  st.header("Know Your Medicine - Multiplication Table Generator")
6
 
7
  # Load the model and tokenizer
8
  @st.cache_resource
9
- def load_model_direct():
10
- model_name = "meta-llama/Llama-3.2-1B"
11
  tokenizer = AutoTokenizer.from_pretrained(model_name)
12
  model = AutoModelForCausalLM.from_pretrained(model_name)
13
  return model, tokenizer
14
 
15
  # Load the model
16
- model, tokenizer = load_model_direct()
17
 
18
  # Input for the number to generate the multiplication table
19
  number = st.number_input("Enter a number:", min_value=1, max_value=100, value=5)
@@ -25,8 +25,8 @@ prompt = f"Give me the multiplication table of {number} up to 12."
25
  if st.button("Generate Multiplication Table"):
26
  # Tokenize the input prompt
27
  tokenized_input = tokenizer(prompt, return_tensors="pt")
28
- input_ids = tokenized_input["input_ids"].cuda() # If running on GPU
29
- attention_mask = tokenized_input["attention_mask"].cuda() # If running on GPU
30
 
31
  # Generate the response from the model
32
  response_token_ids = model.generate(
 
1
  import streamlit as st
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
 
4
  # App header
5
  st.header("Know Your Medicine - Multiplication Table Generator")
6
 
7
  # Load the model and tokenizer
8
  @st.cache_resource
9
+ def load_model():
10
+ model_name = "gpt2" # Using GPT-2 for faster builds
11
  tokenizer = AutoTokenizer.from_pretrained(model_name)
12
  model = AutoModelForCausalLM.from_pretrained(model_name)
13
  return model, tokenizer
14
 
15
  # Load the model
16
+ model, tokenizer = load_model()
17
 
18
  # Input for the number to generate the multiplication table
19
  number = st.number_input("Enter a number:", min_value=1, max_value=100, value=5)
 
25
  if st.button("Generate Multiplication Table"):
26
  # Tokenize the input prompt
27
  tokenized_input = tokenizer(prompt, return_tensors="pt")
28
+ input_ids = tokenized_input["input_ids"] # Using CPU for simplicity
29
+ attention_mask = tokenized_input["attention_mask"] # Using CPU for simplicity
30
 
31
  # Generate the response from the model
32
  response_token_ids = model.generate(