140 lines
5.2 KiB
Python
140 lines
5.2 KiB
Python
from base64 import b64decode
|
||
from dataclasses import dataclass
|
||
from io import BytesIO
|
||
from typing import Tuple
|
||
|
||
from maubot import Plugin, MessageEvent
|
||
from maubot.handlers import command
|
||
from mautrix.types import ImageInfo, EventType, ReactionEvent
|
||
|
||
from PIL import Image
|
||
import magic
|
||
|
||
# Reaction keys handled by the bot: a digit 1-9 followed by VARIATION
# SELECTOR-16 (U+FE0F) and COMBINING ENCLOSING KEYCAP (U+20E3), i.e. 1️⃣ … 9️⃣.
# Only the start is anchored, so a key with trailing characters still matches.
EMOJI_REGEX = r"^[\U00000031-\U00000039]\U0000FE0F\U000020E3"

# Keycap emojis in grid order; the emoji at position i selects image i.
EMOJI_NUMBERS = [f"{digit}\uFE0F\u20E3" for digit in "123456789"]
|
||
|
||
|
||
@dataclass
class IMMeta:
    """Pixel-free record of one image, kept in memory after the data is dropped.

    It holds exactly what is needed to re-send the image later with an
    ``ImageInfo``: where it lives, what it is called, and its format and size.
    """

    uri: str     # content URI of the uploaded image (None if never uploaded)
    name: str    # file name presented to the user
    mime: str    # MIME type of the encoded image data
    size: int    # length of the encoded image data, in bytes
    height: int  # image height in pixels
    width: int   # image width in pixels
|
||
|
||
|
||
class UploadableImage:
    """An encoded image plus a PIL view of it, ready to be uploaded to Matrix.

    Build it either from already-encoded ``image_bytes`` (opened with PIL for
    its dimensions) or from an in-memory PIL image ``pil`` (encoded to PNG).
    """

    def __init__(self, name, image_bytes=None, pil=None):
        """
        :param name: file name used for the upload and in the image info.
        :param image_bytes: encoded image data; when given, these exact bytes
            are what gets uploaded.
        :param pil: PIL image; encoded to PNG when ``image_bytes`` is None,
            and used directly for the dimensions whenever it is given.
        :raises ValueError: if neither ``image_bytes`` nor ``pil`` is given.
        """
        if image_bytes is None and pil is None:
            raise ValueError("either image_bytes or pil must be provided")
        if image_bytes is None:
            with BytesIO() as output:
                pil.save(output, format="PNG")
                image_bytes = output.getvalue()
        self.image_bytes = image_bytes
        self.name = name
        # Sniff the MIME type from the data itself rather than from the name.
        self.mime = magic.from_buffer(self.image_bytes, mime=True)
        self.pil = Image.open(BytesIO(self.image_bytes)) if pil is None else pil
        self.size = len(self.image_bytes)
        # Set by upload().
        self.uri = None

    async def upload(self, client) -> str:
        """Upload the contained image to the Matrix media repository.

        Sets its own ``uri`` to the returned content URI.

        :param client: the Matrix client to upload through (the plugin passes
            ``self.client``); only its async ``upload_media`` is used.
        :returns: the content URI of the uploaded image.
        """
        self.uri = await client.upload_media(
            self.image_bytes,
            self.mime,
            filename=self.name,
        )
        return self.uri

    def meta(self) -> "IMMeta":
        """Return a lightweight copy of this image's metadata without pixel data.

        The result is meant to be kept in memory after this object is dropped.
        """
        return IMMeta(
            uri=self.uri,
            name=self.name,
            mime=self.mime,
            size=self.size,
            height=self.pil.height,
            width=self.pil.width,
        )
|
||
|
||
|
||
class CraiyonBot(Plugin):
    """
    A Maubot plugin. It registers the command

        !craiyon <text>

    and then forwards all images in a 3 by 3 image matrix to the user.
    It also lets the user select single images by reacting to the grid with
    the number emojis (1️⃣ … 9️⃣) the bot adds to it.

    The mxc references are only stored at runtime and are forgotten once the
    bot restarts.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Grid message event ID -> IMMeta of its images, in grid order.
        # Kept per instance: a class-level dict would be shared by every
        # instance of this plugin, so two bots in one room would both answer
        # a reaction on either bot's grid.
        self.images = {}

    def emoji_to_number(self, emoji: str) -> int:
        """Convert a keycap emoji from EMOJI_NUMBERS to its 0-based index.

        :raises ValueError: if the emoji is not in EMOJI_NUMBERS.
        """
        return EMOJI_NUMBERS.index(emoji)

    @command.passive(regex=EMOJI_REGEX,
                     field=lambda evt: evt.content.relates_to.key,
                     event_type=EventType.REACTION,
                     msgtypes=None)
    async def get_image(self, evt: ReactionEvent, _: Tuple[str]) -> None:
        """Handle a number reaction on a grid and send the chosen single image."""
        relation = evt.content.relates_to
        images = self.images.get(relation.event_id)
        if images is None:
            # Not a grid this instance posted (or one from before a restart).
            return
        try:
            index = self.emoji_to_number(relation.key)
        except ValueError:
            # EMOJI_REGEX is only anchored at the start, so keys with trailing
            # characters can reach this handler without being in the list.
            return
        if index >= len(images):
            # The backend returned fewer images than there are number emojis.
            return
        image = images[index]
        await self.client.send_image(evt.room_id, url=image.uri,
                                     file_name=image.name,
                                     info=ImageInfo(mimetype=image.mime,
                                                    size=image.size,
                                                    width=image.width,
                                                    height=image.height))

    async def _fetch_images(self, prompt: str) -> list:
        """Request images for ``prompt`` from craiyon.

        :returns: the base64-encoded images, or an empty list on any failure.
        """
        try:
            async with self.http.post(
                'https://backend.craiyon.com/generate',
                json={'prompt': f'{prompt}<br>'},
            ) as response:
                response.raise_for_status()
                payload = await response.json()
                return payload.get('images') or []
        except Exception:
            # Boundary of a user command: log the cause so the caller can
            # tell the user instead of failing silently.
            self.log.exception("Craiyon request failed")
            return []

    @staticmethod
    def _build_grid(pictures, columns: int = 3):
        """Paste ``pictures`` row-major into a ``columns`` x ``columns`` RGB canvas.

        Each tile is as large as the largest picture, so images of any size
        stay side by side without overlapping. For 256x256 inputs this gives
        the same 768x768 canvas as a hard-coded tile size of 256.
        """
        tile_w = max(picture.width for picture in pictures)
        tile_h = max(picture.height for picture in pictures)
        canvas = Image.new(mode="RGB", size=(tile_w * columns, tile_h * columns))
        for index, picture in enumerate(pictures):
            row, col = divmod(index, columns)
            canvas.paste(picture, (col * tile_w, row * tile_h))
        return canvas

    @command.new()
    @command.argument("prompt", pass_raw=True, required=True)
    async def craiyon(self, evt: MessageEvent, prompt: str) -> None:
        """
        Forward the request to craiyon and reply with the images it created.

        The reply is a 3x3 grid of the images. The grid gets one number
        reaction per image so the user can pick a single one.
        """
        await evt.react("🤖")

        encoded_images = await self._fetch_images(prompt)
        if not encoded_images:
            await evt.reply("Sorry, craiyon did not return any images.")
            return

        # NOTE(review): the raw prompt is used in file names unsanitized.
        images = [
            UploadableImage(f"{prompt}_{n}.jpg", b64decode(data))
            for n, data in enumerate(encoded_images)
        ]
        images_canvas = self._build_grid([image.pil for image in images])

        images_data = []
        for image in images:
            await image.upload(self.client)
            images_data.append(image.meta())

        uploadable_3x3 = UploadableImage(f"{prompt}_3x3.png", pil=images_canvas)
        await uploadable_3x3.upload(self.client)

        msg = await self.client.send_image(evt.room_id, url=uploadable_3x3.uri,
                                           file_name=uploadable_3x3.name,
                                           info=ImageInfo(
                                               mimetype=uploadable_3x3.mime,
                                               size=uploadable_3x3.size,
                                               width=uploadable_3x3.pil.width,
                                               height=uploadable_3x3.pil.height))
        # Store the metadata before adding reactions so a fast click resolves.
        self.images[msg] = images_data
        # Offer only as many choices as there are images.
        for emoji in EMOJI_NUMBERS[:len(images_data)]:
            await evt.client.react(evt.room_id, msg, emoji)
|