File size: 4,852 Bytes
40190c3
41d1bc5
 
 
 
45b4c77
41d1bc5
 
 
 
 
 
 
 
 
 
 
def8b72
 
 
01d0766
41d1bc5
229e7eb
41d1bc5
 
229e7eb
41d1bc5
 
 
 
 
 
 
 
 
229e7eb
41d1bc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f38260
41d1bc5
 
 
 
01d0766
 
 
41d1bc5
 
 
 
 
 
 
 
 
ecf217a
 
41d1bc5
 
 
 
 
 
40190c3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# SambaLATS Streamlit demo: a chat-style UI around the Language Agent Tree
# Search (LATS) pipeline in ./lats, with Samba-1 as the backend model.
import streamlit as st
import openai  # NOTE(review): imported but never used in this file — confirm before removing
import os  # NOTE(review): imported but never used in this file — confirm before removing
import sys
import argparse
# Make the local LATS package importable before pulling in its entry point.
sys.path.append('./lats')
from lats_main import lats_main

st.set_page_config(layout="wide")

# Initialize session state variables if they don't exist.
if 'response_content' not in st.session_state:
    st.session_state.response_content = None

# Creating main columns for the chat and runtime notifications
chat_col = st.container()

chat_col.title("SambaLATS")
# Intro markdown shown above the input box: what LATS is, headline results,
# and a sample LeetCode problem the user can paste in to try the demo.
description = """This demo is an implementation of Language Agent Tree Search (LATS) (https://arxiv.org/abs/2310.04406) with Samba-1 in the backend. Thank you to the original authors of demo on which this is based from [Lapis Labs](https://lapis.rocks/)!

Given Samba-1's lightning quick inference, not only can we accelerate our system's speeds but also improve our system's accuracy. Using many inference calls in this LATS style, we can solve programming questions with higher accuracy. In fact, this system reaches **GPT-3.5 accuracy on HumanEval Python**, 74% accuracy, with LLaMa 3 8B, taking 8 seconds on average. This is a 15.5% boost on LLaMa 3 8B alone. 

Listed below is an example programming problem (https://leetcode.com/problems/median-of-two-sorted-arrays/description/) to get started with. 

```python
Given two sorted arrays `nums1` and `nums2` of size `m` and `n` respectively, return **the median** of the two sorted arrays. The overall run time complexity should be `O(log (m+n))`. **Example 1:** **Input:** nums1 = \[1,3\], nums2 = \[2\] **Output:** 2.00000 **Explanation:** merged array = \[1,2,3\] and median is 2. **Example 2:** **Input:** nums1 = \[1,2\], nums2 = \[3,4\] **Output:** 2.50000 **Explanation:** merged array = \[1,2,3,4\] and median is (2 + 3) / 2 = 2.5. **Constraints:** * `nums1.length == m` * `nums2.length == n` * `0 <= m <= 1000` * `0 <= n <= 1000` * `1 <= m + n <= 2000` * `-106 <= nums1[i], nums2[i] <= 106`
```
"""

chat_col.markdown(description)
sidebar = st.sidebar
# Runtime Section
# NOTE(review): this assignment is shadowed by the `with sidebar:` block
# below, which rebinds runtime_container — this instance appears unused.
runtime_container = st.container()

# Parameters Section
# LATS search hyperparameters, exposed in a collapsible sidebar expander.
sidebar.title("From SambaNova Systems")
parameters_section = sidebar.expander("Parameters", expanded=False)
tree_width = parameters_section.number_input("Tree Width", min_value=1, max_value=5, value=1)
tree_depth = parameters_section.number_input("Tree Depth", min_value=1, max_value=8, value=3)
iterations = parameters_section.number_input("Iterations", min_value=1, max_value=4, value=2)
sidebar.markdown('<hr style="margin-top: 0.5rem; margin-bottom: 0.5rem;">', unsafe_allow_html=True)

# The container that actually receives runtime progress messages lives in
# the sidebar; run_querry() points sys.stdout at it while the search runs.
with sidebar:
    runtime_container = st.container()
    runtime_container.empty()

# NOTE(review): runtime_messages is never read or appended to in this file.
runtime_messages = []

def make_args(instruction, tree_depth, tree_width, iterations):
    """Build the argparse.Namespace consumed by lats_main.

    All configuration is supplied through argparse defaults; nothing is
    read from the command line.

    Args:
        instruction: The programming problem text entered by the user.
        tree_depth: Maximum depth of the search tree.
        tree_width: Number of nodes added during each expansion.
        iterations: Maximum number of search iterations.

    Returns:
        argparse.Namespace with the full LATS configuration.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--strategy", default="mcts", help="Strategy to use")
    parser.add_argument("--language", default="py", help="Programming language")
    parser.add_argument("--max_iters", type=int, default=iterations, help="Maximum iterations")
    parser.add_argument("--instruction", default=instruction, help="Instruction text")
    parser.add_argument("--verbose", action="store_true", help="Verbose output")
    parser.add_argument("--is_leetcode", action='store_true',
                        help="To run the leetcode benchmark")  # Temporary
    parser.add_argument("--n_samples", type=int,
                        help="The number of nodes added during expansion", default=tree_width)
    parser.add_argument("--depth", type=int,
                        help="Tree depth", default=tree_depth)
    # Parse an EMPTY argv so only the defaults above are used. A bare
    # parse_args() would consume the real process argv (e.g. flags passed
    # to `streamlit run`) and can abort with SystemExit on unknown args.
    args = parser.parse_args([])
    return args

def run_querry():
    """Run the LATS pipeline on the text currently in ``user_input``.

    While the search runs, ``sys.stdout`` is pointed at the sidebar runtime
    container so progress prints appear in the UI; a Streamlit container
    exposes ``write()``, which is all ``print()`` needs from a stdout.

    Returns:
        The response produced by ``lats_main``, or ``None`` when the input
        box is empty (no work is done in that case).
    """
    if not user_input:
        return None

    # Create a new container for each subsequent message
    runtime_container.write("Initiating process...")

    # Make it so that prints go to runtime_container writes instead
    old_stdout = sys.stdout
    sys.stdout = runtime_container

    try:
        with chat_col:
            with st.spinner('Running...'):
                args = make_args(user_input, tree_depth, tree_width, iterations)
                setattr(args, 'model', 'samba')
                # main call
                response = lats_main(args)
    finally:
        # Restore stdout even if lats_main raises; otherwise every later
        # print() in the process would still target the sidebar container.
        sys.stdout = old_stdout

    # runtime_container.write("Response fetched.")
    # chat_col.markdown('<hr style="margin-top: 0.5rem; margin-bottom: 0.5rem;">', unsafe_allow_html=True)
    # chat_col.write(f"```python\n{response} \n")

    return response

# User input section at the bottom of the page
# User input section at the bottom of the page: a free-form problem box
# plus a Send button that kicks off the LATS run.
with chat_col:
    user_input = st.text_area("Enter your message here:", placeholder="Type your message here...", label_visibility="collapsed")
    button = st.button("Send")

    if button:
        # Guard: an empty box gets a warning instead of a (pointless) run.
        if user_input == "":
            st.warning("Missing a coding problem")
        else:
            run_querry()