# test-qatool / pages/ImportImageFileEasyOcr.py
# Uploaded by naotakigawa ("Upload 10 files", commit d83dcd1, 2.7 kB).
import os
import streamlit as st
import common
import os
import pickle
import easyocr
from log import logger
from pathlib import Path
from llama_index import Document
# Page setup: login check, environment configuration, and the upload widget.

# Image extensions the uploader accepts.
_IMAGE_TYPES = ("png", "jpg", "jpeg")

# Project helper; presumably blocks visitors who are not logged in
# (defined in common.py -- confirm there).
common.check_login()

# Both settings are required; a missing variable raises KeyError at startup.
INDEX_NAME = os.environ["INDEX_NAME"]  # directory the index is persisted to
PKL_NAME = os.environ["PKL_NAME"]      # pickle file holding the imported names

# The uploader's widget key lives in session state. The import handler bumps
# it after a successful import, which gives the uploader a new key on rerun.
if "file_uploader_key" not in st.session_state:
    st.session_state["file_uploader_key"] = 0

st.title("📝 ImportImageFileEasyOcr")

uploaded_file = st.file_uploader(
    "Upload an article",
    type=_IMAGE_TYPES,
    key=st.session_state["file_uploader_key"],
)
def _ocr_uploaded_image(uploaded):
    """Return the text EasyOCR reads from the uploaded image.

    The upload is written to a temporary file under ``documents/`` so that
    ``readtext`` can be given a file path. The temporary file is removed
    again whether or not OCR succeeds.

    Args:
        uploaded: The object returned by ``st.file_uploader``; its ``name``
            and ``getvalue()`` are used.

    Returns:
        The recognised paragraphs concatenated without a separator.
    """
    # Create the working directory on first use instead of failing in open().
    os.makedirs('documents', exist_ok=True)
    # basename() drops any directory components carried by the upload name.
    filepath = os.path.join('documents', os.path.basename(uploaded.name))
    try:
        with open(filepath, 'wb') as f:
            f.write(uploaded.getvalue())
        logger.info(filepath)
        # NOTE(review): the reader (and its models) is rebuilt on every
        # import; consider caching it if imports become frequent.
        reader = easyocr.Reader(['ja', 'en'], gpu=False)
        paragraphs = reader.readtext(filepath, detail=0, paragraph=True)
    finally:
        # Single cleanup point for both the success and the failure path.
        if os.path.exists(filepath):
            os.remove(filepath)
    return ''.join(paragraphs)


def _import_image(uploaded):
    """OCR the upload, add it to the index, and persist index and name list.

    Any failure is logged and shown to the user instead of being swallowed.

    Args:
        uploaded: The object returned by ``st.file_uploader``.

    Returns:
        True when the document was indexed and persisted, False otherwise.
    """
    try:
        text = _ocr_uploaded_image(uploaded)
        logger.info(text)
        document = Document(text=text)
        logger.info(document)
        document.metadata = {'filename': os.path.basename(uploaded.name)}
        st.session_state.index.insert(document=document)
        # Record the name only after the insert succeeded, so a failed
        # insert does not leave a phantom entry in the imported-file list.
        st.session_state.stored_docs.append(uploaded.name)
        logger.info(st.session_state.stored_docs)
        st.session_state.index.storage_context.persist(persist_dir=INDEX_NAME)
        common.setChatEngine()
        with open(PKL_NAME, "wb") as f:
            logger.info("pickle")
            pickle.dump(st.session_state.stored_docs, f)
    except Exception as e:
        logger.exception(e)
        st.error(f"Failed to import {uploaded.name}: {e}")
        return False
    return True


if st.button("import", use_container_width=True):
    if uploaded_file is None:
        # Clicking the button with nothing selected used to raise
        # AttributeError on uploaded_file.name.
        st.warning("Please upload an image file before importing.")
    elif _import_image(uploaded_file):
        # A new widget key resets the uploader on the next run.
        st.session_state["file_uploader_key"] += 1
        # Kept outside the try block so the rerun signal is never mistaken
        # for an import failure.
        st.experimental_rerun()
# Show the names of every document imported so far in this session.
st.subheader("Import File List")
if "stored_docs" in st.session_state:
    imported_docs = st.session_state.stored_docs
    logger.info(imported_docs)
    # One st.write call per imported name.
    for name in imported_docs:
        st.write(name)