File size: 827 Bytes
668dc85
891948c
9068dfa
 
2dc4a04
9068dfa
2f657d8
f8e14c1
2f657d8
9068dfa
 
ce5c55e
9068dfa
 
 
afbacf2
9068dfa
afbacf2
 
 
 
 
668dc85
79912d7
668dc85
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import functools
import os
import subprocess
import sys
import threading

import gradio as gr
# Checkpoint loaded by the app (the URL is split only for line length; the string is unchanged).
_MODEL_URL = (
    "https://huggingface.co/BlinkDL/rwkv-4-pile-3b/resolve/main/"
    "RWKV-4-Pile-3B-Instruct-test1-20230124.pth"
)
# Packages the original installed at request time; ideally these belong in requirements.txt.
_PIP_PACKAGES = ("rwkvstic", "inquirer", "transformers", "torch", "jax")
# The cached model is shared and stateful, so only one request may drive it at a time.
_MODEL_LOCK = threading.Lock()


@functools.lru_cache(maxsize=1)
def _load_model():
    """Install the runtime dependencies, then load and return the RWKV model.

    The result is cached, so the install and the large checkpoint download
    happen once per process instead of once per request. ``lru_cache`` does
    not cache raised exceptions, so a failed install/load is retried on the
    next call.
    """
    # check=False mirrors the original os.system call: a failed install is not
    # fatal here; the imports below surface the real error instead.
    # sys.executable ensures pip installs into the interpreter running this app.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", *_PIP_PACKAGES], check=False
    )
    # Imported lazily because the packages may only exist after the install above.
    from rwkvstic.load import RWKV
    from rwkvstic.agnostic.backends import JAX

    return RWKV(_MODEL_URL, mode=JAX)


def func(query):
    """Generate a 100-token RWKV continuation of *query* for the Gradio UI.

    Args:
        query: Prompt text entered in the interface.

    Returns:
        The model's ``forward(...)["output"]`` value on success. If
        installing/loading the model fails, returns ``"Error [A]"`` followed
        by the exception text. If generation fails, returns ``"Error [B]"``
        followed by the exception text. Errors are returned rather than
        raised so they are shown in the UI.
    """
    with _MODEL_LOCK:
        try:
            model = _load_model()
        except Exception as e:
            return "Error [A]" + str(e)

        try:
            # The original built a fresh model per call; with a cached model we
            # must clear the recurrent state so earlier prompts don't leak into
            # this one. (Relies on rwkvstic's resetState(); confirm on upgrades.)
            model.resetState()
            model.loadContext(newctx=query)
            return model.forward(number=100)["output"]
        except Exception as e:
            return "Error [B]" + str(e)

# Web UI: a single free-text input routed through ``func`` to a single text output.
iface = gr.Interface(fn=func, inputs="text", outputs="text")

# Start the server only when executed as a script (e.g. ``python app.py``), so
# importing this module for tests or tooling does not block inside launch().
if __name__ == "__main__":
    iface.launch()