maubot-craiyon/craiyonbot.py

141 lines
5.2 KiB
Python
Raw Normal View History

2022-07-26 16:39:08 +02:00
from base64 import b64decode
from dataclasses import dataclass
from io import BytesIO
from typing import Tuple
from maubot import Plugin, MessageEvent
from maubot.handlers import command
from mautrix.types import ImageInfo, EventType, ReactionEvent
from PIL import Image
import magic
# Matches a reaction key that starts with a keycap digit emoji 1-9: an ASCII
# digit followed by VARIATION SELECTOR-16 (U+FE0F) and COMBINING ENCLOSING
# KEYCAP (U+20E3). The escapes are resolved by the `re` module.
EMOJI_REGEX = r"^[\U00000031-\U00000039]\U0000FE0F\U000020E3"

# Keycap emojis 1-9, in order; the list position is the index of the image a
# reaction selects. They are spelled with explicit escapes so the invisible
# code points cannot get lost and every entry is guaranteed to match
# EMOJI_REGEX.
EMOJI_NUMBERS = [f"{digit}\uFE0F\u20E3" for digit in range(1, 10)]
@dataclass
class IMMeta:
    """Metadata describing one image, kept in memory without pixel data."""

    # Value of the image's ``uri`` attribute when the metadata was taken.
    uri: str
    # File name the image is sent under.
    name: str
    # MIME type string of the encoded image.
    mime: str
    # Length of the encoded image in bytes.
    size: int
    # Pixel height of the image.
    height: int
    # Pixel width of the image.
    width: int
class UploadableImage:
    """
    Handles the internals for images.

    Keeps the encoded bytes, the decoded PIL image, the file name, the
    detected MIME type and the byte size together, and remembers the URI
    returned by the upload.
    """

    def __init__(self, name, image_bytes=None, pil=None):
        """
        Build the image from encoded bytes, from a PIL image, or from both.

        name: file name used when uploading and sending the image.
        image_bytes: encoded image data; when omitted, ``pil`` is encoded
            as PNG to produce it.
        pil: PIL image; when omitted, it is opened from ``image_bytes``.

        Raises ValueError when neither ``image_bytes`` nor ``pil`` is given.
        """
        if image_bytes is None and pil is None:
            raise ValueError(
                "UploadableImage needs image_bytes, pil, or both")
        if image_bytes is None:
            # Only a PIL image was supplied: encode it so there are bytes
            # to upload and to measure.
            with BytesIO() as output:
                pil.save(output, format="PNG")
                image_bytes = output.getvalue()
        self.name = name
        self.image_bytes = image_bytes
        # MIME type is sniffed from the data, not derived from the name.
        self.mime = magic.from_buffer(self.image_bytes, mime=True)
        self.pil = (Image.open(BytesIO(self.image_bytes))
                    if pil is None else pil)
        self.size = len(self.image_bytes)
        # Set by upload(); None until then.
        self.uri = None

    async def upload(self, client) -> str:
        """
        Upload the contained image and remember the resulting URI.

        client: object providing an awaitable
            ``upload_media(data, mime, filename=...)``; the plugin passes
            its ``self.client`` here.

        Returns the URI, which is also stored in ``self.uri``.
        """
        self.uri = await client.upload_media(self.image_bytes,
                                             self.mime,
                                             filename=self.name)
        return self.uri

    def meta(self) -> "IMMeta":
        """
        Return a lightweight version of the image (without the pixel data),
        to be stored in memory.

        The ``uri`` field is None if upload() has not been awaited yet.
        """
        return IMMeta(self.uri, self.name, self.mime, self.size,
                      self.pil.height, self.pil.width)
class CraiyonBot(Plugin):
    """
    A Maubot plugin. It registers the command
        !craiyon <text>
    and then forwards all images in a 3 by 3 image matrix to the user.
    It also lets the user pick single images through emoji reactions.
    The mxc references are only stored at runtime and forgotten once the
    bot restarts.
    """

    # Pixel offset between neighbouring tiles on the overview canvas.
    # NOTE(review): presumes craiyon tiles are 256x256 -- confirm.
    TILE_SIZE = 256
    # Number of tiles per row and per column of the overview canvas.
    GRID_SIZE = 3

    # Maps the event id of a sent overview image to the list of IMMeta of
    # its tiles. This class-level dict is only a fallback; start() rebinds
    # a fresh dict per instance so instances do not share state.
    images = {}

    async def start(self) -> None:
        """ Give this plugin instance its own, empty image store """
        await super().start()
        self.images = {}

    def emoji_to_number(self, emoji: str) -> int:
        """
        Helper function to convert a keycap emoji to its zero-based index.

        Only the leading digit is inspected, so "1" followed by the keycap
        code points maps to 0 and "9..." maps to 8.

        Raises ValueError when the text does not start with a digit 1-9.
        """
        # emoji[:1] instead of emoji[0]: an empty string then fails in
        # int() with ValueError rather than with IndexError.
        digit = int(emoji[:1])
        if not 1 <= digit <= 9:
            raise ValueError(f"not a keycap digit between 1 and 9: {emoji!r}")
        return digit - 1

    @command.passive(regex=EMOJI_REGEX,
                     field=lambda evt: evt.content.relates_to.key,
                     event_type=EventType.REACTION,
                     msgtypes=None)
    async def get_image(self, evt: ReactionEvent, _: Tuple[str]) -> None:
        """
        Handler for a reaction click. Sends the chosen image to the room.

        Reactions on messages this instance did not store, and reactions
        whose number has no stored image, are ignored.
        """
        relation = evt.content.relates_to
        stored = self.images.get(relation.event_id)
        if stored is None:
            return
        try:
            index = self.emoji_to_number(relation.key)
        except ValueError:
            return
        if index >= len(stored):
            # Fewer images were generated than the reaction asks for.
            return
        image = stored[index]
        await self.client.send_image(evt.room_id, url=image.uri,
                                     file_name=image.name,
                                     info=ImageInfo(mimetype=image.mime,
                                                    size=image.size,
                                                    width=image.width,
                                                    height=image.height))

    @command.new()
    @command.argument("prompt", pass_raw=True, required=True)
    async def craiyon(self, evt: MessageEvent, prompt: str) -> None:
        """
        Forwards the request to craiyon and returns the images craiyon
        created.

        Every returned image is uploaded on its own and pasted into one
        overview canvas; the overview is sent to the room, its tiles are
        remembered under the sent event id, and one number reaction per
        tile is added so users can request a single image.
        """
        await evt.react("🤖")
        # async with releases the connection on every exit path.
        async with self.http.post(
            'https://backend.craiyon.com/generate',
            json={'prompt': f'{prompt}<br>'}
        ) as response:
            if response.status != 200:
                await evt.reply(
                    f"craiyon request failed (HTTP {response.status})")
                return
            payload = await response.json()

        encoded_images = (payload.get('images')
                          if isinstance(payload, dict) else None)
        if not encoded_images:
            await evt.reply("craiyon did not return any images")
            return

        # Never handle more tiles than the canvas can show or than there
        # are number reactions to select them with.
        capacity = min(self.GRID_SIZE ** 2, len(EMOJI_NUMBERS))
        canvas_edge = self.TILE_SIZE * self.GRID_SIZE
        images_canvas = Image.new(mode="RGB", size=(canvas_edge, canvas_edge))
        images_data = []
        for index, encoded_image in enumerate(encoded_images[:capacity]):
            image = UploadableImage(f"{prompt}_{index}.jpg",
                                    b64decode(encoded_image))
            # Tiles are laid out row by row, left to right.
            pos_x, pos_y = index % self.GRID_SIZE, index // self.GRID_SIZE
            images_canvas.paste(
                image.pil, (pos_x * self.TILE_SIZE, pos_y * self.TILE_SIZE)
            )
            await image.upload(self.client)
            images_data.append(image.meta())

        uploadable_3x3 = UploadableImage(
            f"{prompt}_3x3.png", pil=images_canvas)
        await uploadable_3x3.upload(self.client)
        msg = await self.client.send_image(
            evt.room_id, url=uploadable_3x3.uri,
            file_name=uploadable_3x3.name,
            info=ImageInfo(
                mimetype=uploadable_3x3.mime,
                size=uploadable_3x3.size,
                width=uploadable_3x3.pil.width,
                height=uploadable_3x3.pil.height)
        )
        # Store before reacting so a fast click already finds its image.
        self.images[msg] = images_data
        for emoji in EMOJI_NUMBERS[:len(images_data)]:
            await evt.client.react(evt.room_id, msg, emoji)