hanxu22
commited on
Commit
·
1f16925
1
Parent(s):
2fc0794
add rerank test
Browse files
README.md
CHANGED
@@ -11,3 +11,11 @@ license: apache-2.0
|
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
14 |
+
|
15 |
+
|
16 |
+
# usage
|
17 |
+
```
|
18 |
+
git add .
|
19 |
+
git commit -m "comment"
|
20 |
+
git push
|
21 |
+
```
|
app.py
CHANGED
@@ -1,4 +1,30 @@
|
|
1 |
-
import
|
|
|
2 |
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
2 |
+
import torch
|
3 |
|
4 |
+
# 加载rerank模型和tokenizer
|
5 |
+
model_name = "BAAI/bge-reranker-v2-m3" # 替换为你的rerank模型名称
|
6 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
7 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
8 |
+
|
9 |
+
# 定义候选项和查询
|
10 |
+
query = "What is the capital of France?"
|
11 |
+
candidates = [
|
12 |
+
"Paris is the capital of France.",
|
13 |
+
"Berlin is the capital of Germany.",
|
14 |
+
"Madrid is the capital of Spain."
|
15 |
+
]
|
16 |
+
|
17 |
+
# 对每个候选项进行打分
|
18 |
+
scores = []
|
19 |
+
for candidate in candidates:
|
20 |
+
inputs = tokenizer(query, candidate, return_tensors="pt", truncation=True)
|
21 |
+
with torch.no_grad():
|
22 |
+
logits = model(**inputs).logits
|
23 |
+
scores.append(logits.item())
|
24 |
+
|
25 |
+
# 根据分数对候选项重新排序
|
26 |
+
ranked_candidates = [x for _, x in sorted(zip(scores, candidates), reverse=True)]
|
27 |
+
|
28 |
+
# 输出排序结果
|
29 |
+
for i, candidate in enumerate(ranked_candidates):
|
30 |
+
print(f"Rank {i + 1}: {candidate} (Score: {scores[i]})")
|
demo.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import base64
|
3 |
+
|
4 |
+
from streamlit.components.v1 import html
|
5 |
+
|
6 |
+
st.title("Video Player App")
|
7 |
+
|
8 |
+
video_file_path = "D:\\huggingface\\KHome\\digitalhuman.mp4"
|
9 |
+
# 读取视频文件的内容
|
10 |
+
with open(video_file_path, "rb") as video_file:
|
11 |
+
video_bytes = video_file.read()
|
12 |
+
video_base64 = base64.b64encode(video_bytes).decode('utf-8')
|
13 |
+
|
14 |
+
STOPPED_STATE = f"""
|
15 |
+
<video id="videoPlayer" width="320" height="240" controls>
|
16 |
+
<source src="data:video/mp4;base64,{video_base64}" type="video/mp4">
|
17 |
+
Your browser does not support the video tag.
|
18 |
+
</video>
|
19 |
+
"""
|
20 |
+
|
21 |
+
PLAYING_STATE = f"""
|
22 |
+
<video id="videoPlayer" width="320" height="240" controls autoplay>
|
23 |
+
<source src="data:video/mp4;base64,{video_base64}" type="video/mp4">
|
24 |
+
Your browser does not support the video tag.
|
25 |
+
</video>
|
26 |
+
"""
|
27 |
+
|
28 |
+
|
29 |
+
playstate = st.session_state.get("playstate")
|
30 |
+
if not playstate:
|
31 |
+
playstate = STOPPED_STATE
|
32 |
+
st.markdown(playstate, unsafe_allow_html=True)
|
33 |
+
|
34 |
+
# 播放按钮
|
35 |
+
if st.button("Play"):
|
36 |
+
st.session_state["playstate"] = PLAYING_STATE
|
37 |
+
|
38 |
+
# 暂停按钮
|
39 |
+
if st.button("Pause"):
|
40 |
+
st.session_state["playstate"] = STOPPED_STATE
|
stablediffusion3.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers import DiffusionPipeline
|
2 |
+
|
3 |
+
pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-3-medium")
|