jatnikonm committed on
Commit
135d706
·
1 Parent(s): 712a6ef

kode post-ocr

Browse files
Files changed (4) hide show
  1. .gitignore +2 -0
  2. app.py +89 -28
  3. app.yml +1 -0
  4. requirements.txt +7 -2
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # .gitignore
2
+ .env
app.py CHANGED
@@ -1,36 +1,97 @@
 
 
 
 
 
 
 
 
 
1
  # import gradio as gr
 
 
2
  #
3
- # def greet(name):
4
- # return "Hello " + name + "!!"
 
5
  #
6
- # demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- # demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  #
 
 
 
 
 
 
 
 
9
 
10
  import gradio as gr
11
- from sklearn.neighbors import KNeighborsClassifier
12
- import numpy as np
13
-
14
- # Training data
15
- X = np.array([[1, 2], [2, 3], [3, 1], [6, 5], [7, 7], [8, 6]])
16
- y = np.array([0, 0, 0, 1, 1, 1])
17
-
18
- # Training the model
19
- model = KNeighborsClassifier(n_neighbors=3)
20
- model.fit(X, y)
21
-
22
- # Define the prediction function
23
- def classify_point(x, y):
24
- prediction = model.predict([[x, y]])
25
- return "Class " + str(prediction[0])
26
-
27
- # Create a Gradio interface
28
- demo = gr.Interface(
29
- fn=classify_point,
30
- inputs=["number", "number"],
31
- outputs="text",
32
- description="Predict the class of a point based on its coordinates using K-Nearest Neighbors"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  )
34
 
35
- # Launch the app
36
- demo.launch()
 
1
+ # # import gradio as gr
2
+ # #
3
+ # # def greet(name):
4
+ # # return "Hello " + name + "!!"
5
+ # #
6
+ # # demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
+ # # demo.launch()
8
+ # #
9
+ #
10
  # import gradio as gr
11
+ # from sklearn.neighbors import KNeighborsClassifier
12
+ # import numpy as np
13
  #
14
+ # # Training data
15
+ # X = np.array([[1, 2], [2, 3], [3, 1], [6, 5], [7, 7], [8, 6]])
16
+ # y = np.array([0, 0, 0, 1, 1, 1])
17
  #
18
+ # # Training the model
19
+ # model = KNeighborsClassifier(n_neighbors=3)
20
+ # model.fit(X, y)
21
+ #
22
+ # # Define the prediction function
23
+ # def classify_point(x, y):
24
+ # prediction = model.predict([[x, y]])
25
+ # return "Class " + str(prediction[0])
26
+ #
27
+ # # Create a Gradio interface
28
+ # demo = gr.Interface(
29
+ # fn=classify_point,
30
+ # inputs=["number", "number"],
31
+ # outputs="text",
32
+ # description="Predict the class of a point based on its coordinates using K-Nearest Neighbors"
33
+ # )
34
  #
35
+ # # Launch the app
36
+ # demo.launch()
37
+
38
+ from dotenv import load_dotenv
39
+ import os
40
+ load_dotenv()
41
+
42
+ hf_token = os.getenv("HF_TOKEN")
43
 
44
  import gradio as gr
45
+ from peft import AutoPeftModelForCausalLM
46
+ from transformers import AutoTokenizer, BitsAndBytesConfig
47
+ import torch
48
+
49
+ bnb_config = BitsAndBytesConfig(
50
+ load_in_4bit=True,
51
+ bnb_4bit_use_double_quant=True,
52
+ bnb_4bit_quant_type='nf4',
53
+ bnb_4bit_compute_dtype=torch.bfloat16,
54
+ )
55
+
56
+ model = AutoPeftModelForCausalLM.from_pretrained(
57
+ 'pykale/llama-2-7b-ocr',
58
+ quantization_config=bnb_config,
59
+ low_cpu_mem_usage=True,
60
+ torch_dtype=torch.float16,
61
+ )
62
+ tokenizer = AutoTokenizer.from_pretrained('pykale/llama-2-7b-ocr', token=hf_token)
63
+
64
+ def fix_ocr_errors(ocr):
65
+ prompt = f"""### instruksi:
66
+ perbaiki kata yang salah pada hasil OCR, hasil perbaikan harus dalam bahasa indonesia.
67
+ ### Input:
68
+
69
+ {ocr}
70
+
71
+ ### Response:
72
+ """
73
+
74
+ input_ids = tokenizer(prompt, max_length=1024, return_tensors='pt', truncation=True).input_ids.cuda()
75
+ with torch.inference_mode():
76
+ outputs = model.generate(
77
+ input_ids=input_ids,
78
+ max_new_tokens=1024,
79
+ do_sample=True,
80
+ temperature=0.7,
81
+ top_p=0.1,
82
+ top_k=40
83
+ )
84
+ pred = tokenizer.decode(outputs[0], skip_special_tokens=True)
85
+ corrected_text = pred[len(prompt):].strip()
86
+ return corrected_text
87
+
88
+ iface = gr.Interface(
89
+ fn=fix_ocr_errors,
90
+ inputs=gr.Textbox(lines=5, placeholder="Masukkan teks OCR di sini..."),
91
+ outputs=gr.Textbox(label="text"),
92
+ title="Perbaiki Kesalahan OCR",
93
+ description="Masukkan teks dengan kesalahan OCR dan model akan mencoba memperbaikinya."
94
  )
95
 
96
+ if __name__ == "__main__":
97
+ iface.launch()
app.yml CHANGED
@@ -6,3 +6,4 @@ dependencies:
6
  - gradio
7
  - scikit-learn
8
  - torch
 
 
6
  - gradio
7
  - scikit-learn
8
  - torch
9
+ gpu: true
requirements.txt CHANGED
@@ -1,3 +1,8 @@
1
  gradio
2
- scikit-learn
3
- torch
 
 
 
 
 
 
1
  gradio
2
+ scikit-learn~=1.3.2
3
+ torch
4
+ numpy~=1.24.4
5
+ transformers>=4.31.0
6
+ peft>=0.4.0
7
+ bitsandbytes>=0.39.0
8
+ python-dotenv