Update my_model/tabs/results.py
Browse files- my_model/tabs/results.py +23 -17
my_model/tabs/results.py
CHANGED
@@ -59,23 +59,8 @@ class ResultDemonstrator(KBVQAEvaluator):
|
|
59 |
# Compute colors and labels for the plot (assuming these columns are already in the DataFrame)
|
60 |
d['color'] = d['vqa_score_13b_caption+detic'].apply(lambda x: 'green' if x == 1 else 'orange' if round(x, 2) == 0.67 else 'red')
|
61 |
d['label'] = d['vqa_score_13b_caption+detic'].apply(lambda x: 'Correct' if x == 1 else 'Partially Correct' if x == 0.67 else 'Incorrect')
|
62 |
-
|
63 |
-
# Creating the scatter plot
|
64 |
-
scatter_chart = alt.Chart(d).mark_circle(size=20).encode(
|
65 |
-
x=alt.X('index:Q', title='Index'), # Assuming 'index' is a column or can be created
|
66 |
-
y=alt.Y('token_counts:Q', title='Number of Tokens'),
|
67 |
-
color=alt.Color('color:N', legend=alt.Legend(title="VQA Score")),
|
68 |
-
tooltip=['index', 'token_counts', 'label']
|
69 |
-
).interactive()
|
70 |
-
|
71 |
-
# Display the chart
|
72 |
-
st.altair_chart(scatter_chart, use_container_width=True)
|
73 |
-
|
74 |
-
|
75 |
-
####################
|
76 |
-
scores = d['vqa_score_13b_caption+detic']
|
77 |
|
78 |
-
|
79 |
|
80 |
plt.figure(figsize=(10, 6))
|
81 |
# Create a scatter plot with smaller dots using the 's' parameter
|
@@ -94,4 +79,25 @@ class ResultDemonstrator(KBVQAEvaluator):
|
|
94 |
plt.ylabel('Number of Tokens')
|
95 |
|
96 |
# Display the plot
|
97 |
-
plt.show()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
# Compute colors and labels for the plot (assuming these columns are already in the DataFrame)
|
60 |
d['color'] = d['vqa_score_13b_caption+detic'].apply(lambda x: 'green' if x == 1 else 'orange' if round(x, 2) == 0.67 else 'red')
|
61 |
d['label'] = d['vqa_score_13b_caption+detic'].apply(lambda x: 'Correct' if x == 1 else 'Partially Correct' if x == 0.67 else 'Incorrect')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
+
# Define colors and labels for the legend
|
64 |
|
65 |
plt.figure(figsize=(10, 6))
|
66 |
# Create a scatter plot with smaller dots using the 's' parameter
|
|
|
79 |
plt.ylabel('Number of Tokens')
|
80 |
|
81 |
# Display the plot
|
82 |
+
plt.show()
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
+
|
87 |
+
|
88 |
+
|
89 |
+
|
90 |
+
# Creating the scatter plot
|
91 |
+
scatter_chart = alt.Chart(d).mark_circle(size=20).encode(
|
92 |
+
x=alt.X('index:Q', title='Index'), # Assuming 'index' is a column or can be created
|
93 |
+
y=alt.Y('token_counts:Q', title='Number of Tokens'),
|
94 |
+
color=alt.Color('color:N', legend=alt.Legend(title="VQA Score")),
|
95 |
+
tooltip=['index', 'token_counts', 'label']
|
96 |
+
).interactive()
|
97 |
+
|
98 |
+
# Display the chart
|
99 |
+
st.altair_chart(scatter_chart, use_container_width=True)
|
100 |
+
|
101 |
+
|
102 |
+
####################
|
103 |
+
|