|
import logging

from transformers import Tool
from transformers import pipeline
|
|
|
class TextGenerationTool(Tool):
    """Agent tool that continues a text prompt using a Hugging Face pipeline.

    The ``text-generation`` pipeline (and therefore the model weights) is
    built lazily on the first call via :meth:`setup` and then reused for every
    later call on the same instance.
    """

    name = "text_generator"

    description = (
        "This is a tool for text generation. It takes a prompt as input and returns the generated text."
    )

    inputs = ["text"]

    outputs = ["text"]

    # Hub checkpoint loaded by setup(); a subclass may override it to use another model.
    model_id = "mistralai/Mistral-7B-Instruct-v0.1"

    def setup(self):
        """Build the text-generation pipeline once and mark the tool initialized.

        Loading the model is expensive, so this is done a single time per
        instance (triggered by the first ``__call__``) rather than per prompt.
        """
        # Make the task explicit instead of having ``pipeline`` infer it from
        # the model's Hub metadata.
        self._text_generator = pipeline("text-generation", model=self.model_id)
        # Base-class hook: flags the tool as initialized.
        super().setup()

    def __call__(self, prompt: str, max_length: int = 500, temperature: float = 0.7):
        """Generate a continuation of ``prompt``.

        Args:
            prompt: Input text to continue.
            max_length: Upper bound on the total sequence length in tokens
                (per the transformers generation docs this counts the prompt
                tokens as well as the newly generated ones).
            temperature: Sampling temperature used for generation.

        Returns:
            The pipeline output for a single sequence, returned unchanged
            (per the transformers docs, a list holding one dict whose
            ``"generated_text"`` key contains the text).
        """
        # Lazily load the model on first use; later calls reuse the cached pipeline.
        if getattr(self, "_text_generator", None) is None:
            self.setup()

        result = self._text_generator(
            prompt,
            max_length=max_length,
            num_return_sequences=1,
            # Without do_sample=True generation is greedy and ``temperature``
            # has no effect, which silently defeats the requested 0.7.
            do_sample=True,
            temperature=temperature,
        )

        # Route debug output through logging instead of print so library users
        # are not spammed on stdout.
        logging.getLogger(__name__).debug("text_generator output: %r", result)

        # NOTE(review): ``description`` and ``outputs`` promise plain text, but the
        # raw pipeline output is returned unchanged to keep existing callers
        # working; consider returning ``result[0]["generated_text"]`` instead.
        return result
|
|
|
|
|
|