Pipatpong commited on
Commit
6434bf6
1 Parent(s): 6b8d2ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -2,11 +2,10 @@
2
 
3
  import gradio as gr
4
  import re
5
- import torch
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
 
8
  checkpoint = "Pipatpong/vcm_santa"
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
  tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
11
  model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True, device_map="auto", load_in_8bit=True)
12
 
@@ -17,7 +16,8 @@ def generate(text, max_length, num_return_sequences=1):
17
  return gen_text
18
 
19
 
20
- def extract_functions(text):
 
21
  function_pattern = r'def\s+(\w+)\((.*?)\):([\s\S]*?)return\s+(.*?)\n'
22
  functions = re.findall(function_pattern, text, flags=re.MULTILINE)
23
  extracted_text = []
 
2
 
3
  import gradio as gr
4
  import re
 
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
 
7
  checkpoint = "Pipatpong/vcm_santa"
8
+ device = "cpu"
9
  tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
10
  model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True, device_map="auto", load_in_8bit=True)
11
 
 
16
  return gen_text
17
 
18
 
19
+ def extract_functions
20
+ (text):
21
  function_pattern = r'def\s+(\w+)\((.*?)\):([\s\S]*?)return\s+(.*?)\n'
22
  functions = re.findall(function_pattern, text, flags=re.MULTILINE)
23
  extracted_text = []