jeffeux committed on
Commit
8446129
1 Parent(s): a31cbcf
Files changed (2)
  1. .gitignore +1 -0
  2. main.py +15 -13
.gitignore CHANGED
@@ -1,2 +1,3 @@
 # environment
 bloom_demo
+tutorial-env
main.py CHANGED
@@ -42,17 +42,19 @@ def model_init():
 
 tokenizer, model = model_init()
 
-# ===================== INPUT ====================== #
-# prompt = "問：台灣最高的建築物是？答：" #@param {type:"string"}
-prompt = st.text_input("Prompt: ")
+try:
+    # ===================== INPUT ====================== #
+    # prompt = "問：台灣最高的建築物是？答：" #@param {type:"string"}
+    prompt = st.text_input("Prompt: ")
 
-# =================== INFERENCE ==================== #
-if prompt:
-    with torch.no_grad():
-        [texts_out] = model.generate(
-            **tokenizer(
-                prompt, return_tensors="pt"
-            ).to(device))
-    output_text = tokenizer.decode(texts_out)
-    st.markdown(output_text)
-
+    # =================== INFERENCE ==================== #
+    if prompt:
+        with torch.no_grad():
+            [texts_out] = model.generate(
+                **tokenizer(
+                    prompt, return_tensors="pt"
+                ).to(device))
+        output_text = tokenizer.decode(texts_out)
+        st.markdown(output_text)
+except Exception as err:
+    st.write(str(err))
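
For context, a minimal sketch of how the tail of main.py might read once this commit is applied. Only the try/except block comes from the hunk above; the imports, the device selection, and the body of model_init are not visible in this diff and are assumptions here, and the bigscience/bloom-560m checkpoint is a hypothetical placeholder (chosen only because the ignored environment directory is named bloom_demo).

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumption: device handling is not shown in this hunk.
device = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource  # assumption: cache the model across Streamlit reruns; not shown in the diff
def model_init():
    # Hypothetical checkpoint; the real model name does not appear in this diff.
    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
    model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m").to(device)
    model.eval()
    return tokenizer, model


tokenizer, model = model_init()

try:
    # ===================== INPUT ====================== #
    prompt = st.text_input("Prompt: ")

    # =================== INFERENCE ==================== #
    if prompt:
        with torch.no_grad():
            # generate() returns a batch with a single sequence; unpack it directly.
            [texts_out] = model.generate(
                **tokenizer(prompt, return_tensors="pt").to(device))
        output_text = tokenizer.decode(texts_out)
        st.markdown(output_text)
except Exception as err:
    # New in this commit: surface any runtime error in the page instead of crashing the app.
    st.write(str(err))

Under those assumptions the app would be started with: streamlit run main.py. The added try/except simply writes any runtime error into the page rather than letting the Streamlit script fail silently.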