resumescan / bertimproved.py
sanjay11's picture
Create bertimproved.py
e778d13
raw history blame
No virus
2.45 kB
import streamlit as st
from transformers import BertForQuestionAnswering, BertTokenizer
import torch
from io import BytesIO
import PyPDF2
import pandas as pd
# Streamlit re-executes this script on every interaction; keep the list of
# logged question/answer/satisfaction records in session state so it persists
# across those reruns. It is created only when it does not exist yet.
if "qa_log" not in st.session_state:
    st.session_state["qa_log"] = []
def extract_text_from_pdf(pdf_file):
    """Return the text of every page of a PDF, pages separated by newlines.

    Args:
        pdf_file: A binary file-like object (e.g. a Streamlit ``UploadedFile``)
            whose remaining bytes are a PDF document. It is consumed with
            ``read()``.

    Returns:
        The extracted text of all pages, joined with ``"\\n"`` so the last
        word of one page is not fused with the first word of the next.
    """
    pdf_reader = PyPDF2.PdfReader(BytesIO(pdf_file.read()))
    # Defensive: treat a falsy/None result (e.g. an image-only page) as empty
    # text rather than letting the join fail.
    page_texts = [page.extract_text() or "" for page in pdf_reader.pages]
    return "\n".join(page_texts)
def answer_question(question, context, model, tokenizer, max_answer_tokens=30):
    """Extract the span of ``context`` that best answers ``question``.

    Args:
        question: The user's question.
        context: The text to search for an answer (here, the resume text).
        model: An extractive QA model (e.g. ``BertForQuestionAnswering``) whose
            output exposes ``start_logits`` and ``end_logits``.
        tokenizer: The tokenizer matching ``model``.
        max_answer_tokens: Longest answer span, in tokens, that is considered.

    Returns:
        The answer text decoded from the highest-scoring valid span, or ``""``
        if the (possibly truncated) context contributes no tokens.

    Raises:
        ValueError: If ``max_answer_tokens`` is smaller than 1.

    Note:
        The question/context pair is truncated to 512 tokens by cutting only
        the context, so text beyond that window is never searched.
    """
    if max_answer_tokens < 1:
        raise ValueError("max_answer_tokens must be at least 1")

    inputs = tokenizer.encode_plus(
        question,
        context,
        add_special_tokens=True,
        return_tensors="pt",
        truncation="only_second",
        max_length=512,
    )
    # Inference only: no need to build the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs, return_dict=True)

    # Batch size is 1, so take the first (only) row of every tensor.
    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]
    input_ids = inputs["input_ids"][0].tolist()
    seq_len = len(input_ids)

    # Only context tokens may form an answer: segment 1 in token_type_ids,
    # excluding special tokens such as [SEP]. Without token_type_ids, fall
    # back to every non-special token.
    token_type_ids = inputs.get("token_type_ids")
    if token_type_ids is None:
        in_context = torch.ones(seq_len, dtype=torch.bool)
    else:
        in_context = token_type_ids[0] == 1
    special_ids = set(tokenizer.all_special_ids)
    not_special = torch.tensor(
        [tok not in special_ids for tok in input_ids], dtype=torch.bool
    )
    candidate = in_context & not_special
    if not bool(candidate.any()):
        return ""

    # Score all (start, end) pairs jointly. Two independent argmaxes can give
    # end < start, which produced an empty answer.
    scores = start_logits[:, None] + end_logits[None, :]
    ones = torch.ones(seq_len, seq_len, dtype=torch.bool)
    # start <= end < start + max_answer_tokens
    valid = ones.triu() & ~ones.triu(max_answer_tokens)
    valid &= candidate[:, None] & candidate[None, :]
    scores = scores.masked_fill(~valid, float("-inf"))
    answer_start, answer_end = divmod(int(torch.argmax(scores)), seq_len)

    return tokenizer.convert_tokens_to_string(
        tokenizer.convert_ids_to_tokens(input_ids[answer_start : answer_end + 1])
    )
_QA_MODEL_NAME = "bert-large-uncased-whole-word-masking-finetuned-squad"


@st.cache_resource
def _load_qa_model():
    """Load the SQuAD-finetuned BERT QA model and its tokenizer.

    Cached with ``st.cache_resource`` so the large weights are loaded once and
    reused across Streamlit reruns instead of being reloaded per question.
    """
    model = BertForQuestionAnswering.from_pretrained(_QA_MODEL_NAME)
    tokenizer = BertTokenizer.from_pretrained(_QA_MODEL_NAME)
    return model, tokenizer


def _log_interaction(question, answer, satisfaction):
    """Record one QA interaction in ``st.session_state.qa_log``.

    Streamlit re-executes the script on every widget interaction, so the same
    question is processed again whenever the user changes the satisfaction
    radio. In that case the latest entry is updated in place, keeping one row
    per question with its current satisfaction instead of a duplicate row per
    rerun.
    """
    entry = {
        'Question': question,
        'Answer': answer,
        'Satisfaction': satisfaction,
    }
    log = st.session_state.qa_log
    if log and log[-1]['Question'] == question and log[-1]['Answer'] == answer:
        log[-1] = entry
    else:
        log.append(entry)


st.title("Resume Question Answering")
uploaded_file = st.file_uploader("Upload your resume (PDF format only)", type=["pdf"])
if uploaded_file is not None:
    resume_text = extract_text_from_pdf(uploaded_file)
    st.write("Resume Text:")
    st.write(resume_text)
    if not resume_text.strip():
        st.warning("No text could be extracted from this PDF; it may be a scanned image.")
    user_question = st.text_input("Ask a question based on your resume:")
    if user_question:
        model, tokenizer = _load_qa_model()
        answer = answer_question(user_question, resume_text, model, tokenizer)
        st.write("Answer:")
        st.write(answer)
        # Ask for user feedback on satisfaction.
        # NOTE(review): st.radio preselects its first option, so 'Yes' is
        # logged until the user changes it.
        satisfaction = st.radio('Are you satisfied with the answer?', ('Yes', 'No'), key='satisfaction')
        _log_interaction(user_question, answer, satisfaction)

# Display the log in a table format
st.write("Interaction Log:")
log_df = pd.DataFrame(st.session_state.qa_log)
st.dataframe(log_df)