Food / app.py
Noracle's picture
Update app.py
c25f19c verified
import streamlit as st
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM
import torch
# Cache model loading so the models are not loaded repeatedly
@st.cache_resource
def load_pipeline1_model():
    """Build the essay-scoring pipeline.

    Loads the tokenizer and sequence-classification weights for the
    ``yitongwu73/finetuned-roberta-base`` checkpoint and wraps them in a
    ``text-classification`` pipeline. The ``st.cache_resource`` decorator
    keeps the returned object cached so it is not rebuilt on each call.

    Returns:
        The ready-to-call ``text-classification`` pipeline.
    """
    checkpoint = "yitongwu73/finetuned-roberta-base"
    scoring_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    scoring_model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

    # Use the first GPU when CUDA is available; -1 selects the CPU.
    if torch.cuda.is_available():
        target_device = 0
    else:
        target_device = -1

    return pipeline(
        "text-classification",
        model=scoring_model,
        tokenizer=scoring_tokenizer,
        device=target_device,
    )
@st.cache_resource
def load_pipeline2_model():
    """Build the feedback-generation pipeline.

    Loads the tokenizer and seq2seq weights for the
    ``Noracle/finetuned-distilbart-cnn-12-6`` checkpoint and wraps them in a
    ``text2text-generation`` pipeline. The ``st.cache_resource`` decorator
    keeps the returned object cached so it is not rebuilt on each call.

    Returns:
        The ready-to-call ``text2text-generation`` pipeline.
    """
    checkpoint = "Noracle/finetuned-distilbart-cnn-12-6"
    feedback_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    feedback_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

    # Use the first GPU when CUDA is available; -1 selects the CPU.
    if torch.cuda.is_available():
        target_device = 0
    else:
        target_device = -1

    return pipeline(
        "text2text-generation",
        model=feedback_model,
        tokenizer=feedback_tokenizer,
        device=target_device,
    )
def _logit_to_band_score(logit, min_logit=-5.0, max_logit=5.0,
                         min_score=3.5, max_score=9.0):
    """Linearly map a raw model output onto the band scale in 0.5 steps.

    Args:
        logit: Raw (un-normalised) score produced by the scoring model.
        min_logit: Model output that corresponds to ``min_score``.
        max_logit: Model output that corresponds to ``max_score``.
            NOTE(review): the [-5, 5] range is a placeholder carried over
            from the original code and should be tuned to the model's
            actual output range.
        min_score: Lowest band that can be reported.
        max_score: Highest band that can be reported.

    Returns:
        The band score rounded to the nearest 0.5 and clamped to
        ``[min_score, max_score]``.
    """
    fraction = (logit - min_logit) / (max_logit - min_logit)
    band = fraction * (max_score - min_score) + min_score
    # Clamp so that outputs outside the assumed logit range can never be
    # displayed as a band below the minimum or above "/9.0".
    band = min(max(band, min_score), max_score)
    return round(band * 2) / 2  # keep 0.5 increments


def main():
    """Render the evaluation page and, on submit, score the essay.

    Shows a form with the essay prompt and essay content. When the form is
    submitted with both fields filled in, runs the scoring pipeline to get
    an overall band score and the generation pipeline to get written
    feedback, then displays both.
    """
    st.set_page_config(
        page_title="IELTS Essay Scoring & Feedback System",
        page_icon="📝",
        layout="wide",
    )
    st.title("NewChannel IELTS Essay Evaluation System")
    st.markdown("### Please enter your essay prompt and content")

    # Input form layout
    with st.form("evaluation_form"):
        prompt = st.text_area("Essay Prompt", height=100)
        essay = st.text_area("Essay Content", height=300)
        submitted = st.form_submit_button("Generate Evaluation Report")

    if not submitted:
        return

    # Treat whitespace-only input as empty so the models are never run on
    # blank text.
    prompt = (prompt or "").strip()
    essay = (essay or "").strip()
    if not prompt or not essay:
        st.warning("Please complete both the essay prompt and content")
        return

    # Show a loading indicator while both models run
    with st.spinner("Evaluating your essay..."):
        # Pipeline 1: overall score
        pipe1 = load_pipeline1_model()
        input_text = f"Scoring Task: Prompt: {prompt} Essay: {essay}"
        # function_to_apply="none" returns the raw model output; by default
        # the pipeline post-processes it (sigmoid/softmax), which does not
        # match the logit range the mapping below expects.
        # truncation=True keeps long essays within the model's input limit.
        score_output = pipe1(
            input_text,
            truncation=True,
            function_to_apply="none",
        )[0]
        overall_score = _logit_to_band_score(score_output['score'])

        # Pipeline 2: detailed feedback
        pipe2 = load_pipeline2_model()
        input_text2 = f"Generate Feedback: Prompt: {prompt} Essay: {essay}"
        feedback_output = pipe2(
            input_text2,
            max_length=1024,
            num_return_sequences=1,
            truncation=True,
        )[0]
        feedback_text = feedback_output['generated_text']

    # Display the results
    st.markdown("### Evaluation Results")
    st.markdown(f"#### Overall Band Score: **{overall_score}/9.0**")
    st.markdown("#### Detailed Feedback")
    st.markdown(feedback_text)
# Entry point: build the page only when this file is run as the main module.
if __name__ == "__main__":
    main()