# RAG01 / test_workflow.py
# NaderAfshar
# updated code and implemented a new test: test_workflow
# 55b7d0c
from llama_index.core.workflow import (
StartEvent,
StopEvent,
Workflow,
step,
Event,
Context,
)
import asyncio
import nest_asyncio
from llama_index.llms.groq import Groq
from llama_index.utils.workflow import draw_all_possible_flows
from IPython.display import display, HTML
from dotenv import load_dotenv
from helper import extract_html_content
from pathlib import Path
import os
# Load settings from a local .env file into the process environment before
# anything below reads them.
load_dotenv()

# Patch asyncio so event loops may nest — presumably so this script also
# works when driven from a notebook kernel that already runs a loop.
nest_asyncio.apply()

# API key for Groq (None if unset), and the single LLM client every
# workflow step shares.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
global_llm = Groq(model="llama3-70b-8192", api_key=GROQ_API_KEY)
class FirstEvent(Event):
    """Event returned by ``step_one``; it triggers ``step_two``."""

    # Status text describing the outcome of the first step.
    first_output: str
class SecondEvent(Event):
    """Event returned by ``step_two``; it triggers ``step_three``."""

    # Status text describing the outcome of the second step.
    second_output: str
    # Text of the LLM completion that step_two streamed.
    response: str
class ProgressEvent(Event):
    """Progress message written to the workflow's event stream.

    No step takes this event as input; ``main`` reads it from
    ``handler.stream_events()`` and prints it.
    """

    # Text to show to whoever is consuming the event stream.
    msg: str
class MyWorkflow(Workflow):
    """Three-step demo workflow that reports progress while it runs.

    Flow: StartEvent -> step_one -> FirstEvent -> step_two -> SecondEvent
    -> step_three -> StopEvent. Each step also writes ProgressEvent messages
    to the context's event stream, so a caller iterating
    ``handler.stream_events()`` can watch progress live.
    """

    @step
    async def step_one(self, ctx: Context, ev: StartEvent) -> FirstEvent:
        """Announce the start of the workflow and hand off to step_two."""
        ctx.write_event_to_stream(ProgressEvent(msg="Step one is happening"))
        return FirstEvent(first_output="First step complete.")

    @step
    async def step_two(self, ctx: Context, ev: FirstEvent) -> SecondEvent:
        """Stream an LLM completion, forwarding each chunk as a ProgressEvent.

        Returns:
            A SecondEvent whose ``response`` is ``str()`` of the last streamed
            chunk, or ``""`` if the stream produced no chunks.
        """
        generator = await global_llm.astream_complete(
            "Please give me the first 3 paragraphs of Moby Dick, a book in the public domain."
        )
        # Track the last chunk explicitly: an empty stream then yields ""
        # instead of a NameError from reading an unbound loop variable.
        last_chunk = None
        async for chunk in generator:
            # Stream each incremental piece of text as it arrives. ``delta`` is
            # Optional on llama_index completion chunks (confirm for your
            # version) while ProgressEvent.msg is a ``str``, so map None to "".
            ctx.write_event_to_stream(ProgressEvent(msg=chunk.delta or ""))
            last_chunk = chunk
        # Streamed chunks are expected to carry the accumulated text so far,
        # so the last one's str() is the full answer (llama_index convention).
        full_text = "" if last_chunk is None else str(last_chunk)
        return SecondEvent(
            second_output="Second step complete, full response attached",
            response=full_text,
        )

    @step
    async def step_three(self, ctx: Context, ev: SecondEvent) -> StopEvent:
        """Report the final step and end the workflow."""
        ctx.write_event_to_stream(ProgressEvent(msg="Step three is happening"))
        return StopEvent(result="Workflow complete.")
async def main():
    """Run MyWorkflow, echo its progress stream, then render its flow graph.

    Prints every ProgressEvent message as it arrives, then prints the final
    result. Next it writes an HTML diagram of all possible step transitions
    to ``workflows/streaming_workflow.html`` beside this file and displays
    that HTML via IPython.
    """
    w = MyWorkflow(timeout=30, verbose=True)
    # step_one does not read anything from the StartEvent it receives.
    handler = w.run(first_input="Start the workflow.")

    async for ev in handler.stream_events():
        if isinstance(ev, ProgressEvent):
            print(ev.msg)

    final_result = await handler
    print("Final result", final_result)

    workflow_file = Path(__file__).parent / "workflows" / "streaming_workflow.html"
    # The output directory is never created elsewhere in this script; make
    # sure it exists so writing the diagram cannot fail on a missing folder.
    workflow_file.parent.mkdir(parents=True, exist_ok=True)
    draw_all_possible_flows(w, filename=str(workflow_file))
    html_content = extract_html_content(str(workflow_file))
    # NOTE(review): isolated=True is presumably meant to make the notebook
    # front end render this HTML in its own iframe — confirm.
    display(HTML(html_content), metadata=dict(isolated=True))
if __name__ == "__main__":
    # Script entry point: run the async demo to completion on a fresh loop.
    asyncio.run(main())