rag-tool / text_generation.py
Chris4K's picture
Update text_generation.py
e03f966
raw
history blame
846 Bytes
from transformers import Tool
from transformers import pipeline
class TextGenerationTool(Tool):
    """Tool that forwards a prompt to a `transformers` pipeline.

    The pipeline is built from the ``bigcode/starcoder`` checkpoint the first
    time the tool is called and is then reused for later calls on the same
    instance, so the model is not reloaded on every invocation.
    """

    name = "text_generator"
    description = (
        "This is a tool for text generation. It takes a prompt as input and returns the generated text."
    )
    inputs = ["text"]
    outputs = ["text"]

    # Cached pipeline; stays None until the first call so that merely
    # instantiating the tool does not trigger a model load.
    _text_generator = None

    def __call__(self, prompt: str):
        """Run the pipeline on *prompt* and return its raw output.

        Args:
            prompt: Text passed to the pipeline as the generation prompt.

        Returns:
            Whatever the pipeline call returns, unmodified.
            NOTE(review): for a text-generation pipeline this is presumably a
            list of dicts rather than a plain string, which would not match
            the declared ``outputs = ["text"]`` -- confirm what callers expect
            before changing the return shape.
        """
        # Building the pipeline loads the model, which is by far the most
        # expensive step; do it once and keep the result on the instance.
        if self._text_generator is None:
            self._text_generator = pipeline(model="bigcode/starcoder")

        # NOTE(review): `temperature` likely has no effect unless sampling is
        # enabled (`do_sample=True`) -- confirm against the generation docs.
        generated_text = self._text_generator(
            prompt,
            max_length=500,
            num_return_sequences=1,
            temperature=0.7,
        )

        # Echo the raw pipeline output to stdout (kept from the original).
        print(generated_text)
        return generated_text