141 lines
5.2 KiB
Python
141 lines
5.2 KiB
Python
|
from base64 import b64decode
|
|||
|
from dataclasses import dataclass
|
|||
|
from io import BytesIO
|
|||
|
from typing import Tuple
|
|||
|
|
|||
|
from maubot import Plugin, MessageEvent
|
|||
|
from maubot.handlers import command
|
|||
|
from mautrix.types import ImageInfo, EventType, ReactionEvent
|
|||
|
|
|||
|
from PIL import Image
|
|||
|
import magic
|
|||
|
|
|||
|
# Matches exactly one keycap digit reaction key (1-9): the ASCII digit followed
# by U+FE0F (emoji presentation selector) and U+20E3 (combining enclosing
# keycap). Anchored at both ends so keys that merely *start* with a keycap are
# not routed to the reaction handler, where EMOJI_NUMBERS.index() would fail.
EMOJI_REGEX = r"^[1-9]\uFE0F\u20E3$"

# The keycap reactions offered under each image matrix, in grid order.
# Built from the same digit + U+FE0F + U+20E3 form the regex accepts, so the
# two constants cannot drift apart.
EMOJI_NUMBERS = [f"{digit}\uFE0F\u20E3" for digit in "123456789"]
|
|||
|
|
|||
|
|
|||
|
@dataclass()
class IMMeta:
    """ Pixel-free description of an uploaded image, cheap to keep in memory """

    # content URI the image was uploaded to (UploadableImage.uri)
    uri: str
    # file name sent along with the image
    name: str
    # MIME type sniffed from the encoded bytes
    mime: str
    # length of the encoded image in bytes
    size: int
    # image height in pixels
    height: int
    # image width in pixels
    width: int
|
|||
|
|
|||
|
|
|||
|
class UploadableImage:
    """ Holds an encoded image, its decoded PIL form and its upload state """

    def __init__(self, name, image_bytes=None, pil=None):
        """
        :param name: file name used for the upload and the sent message.
        :param image_bytes: encoded image data. If omitted, ``pil`` is
            encoded as PNG to produce it.
        :param pil: already-decoded PIL image. If omitted, it is decoded
            from ``image_bytes``.
        :raises ValueError: if neither ``image_bytes`` nor ``pil`` is given.
        """
        if image_bytes is None and pil is None:
            # Previously this surfaced as an opaque AttributeError on None.save
            raise ValueError("UploadableImage needs image_bytes or pil")
        if image_bytes is None:
            with BytesIO() as output:
                pil.save(output, format="PNG")
                image_bytes = output.getvalue()
        self.image_bytes = image_bytes
        self.name = name
        # Sniff the real type from the bytes rather than trusting the name
        self.mime = magic.from_buffer(self.image_bytes, mime=True)
        self.pil = Image.open(BytesIO(self.image_bytes)) if pil is None else pil
        self.size = len(self.image_bytes)
        # Filled in by upload()
        self.uri = None

    async def upload(self, client: Plugin) -> str:
        """
        Upload the encoded image with ``client.upload_media`` and store the
        returned URI in ``self.uri``.

        NOTE(review): callers pass the plugin's Matrix client
        (``Plugin.client``), not a Plugin; the annotation is kept unchanged.

        :returns: the URI returned by ``upload_media``.
        """
        self.uri = await client.upload_media(self.image_bytes,
                                             self.mime,
                                             filename=self.name)
        return self.uri

    def meta(self) -> IMMeta:
        """
        Return a lightweight copy of this image's metadata (no pixel data
        or encoded bytes) suitable for keeping in memory.
        """
        return IMMeta(self.uri, self.name, self.mime, self.size,
                      self.pil.height, self.pil.width)
|
|||
|
|
|||
|
|
|||
|
class CraiyonBot(Plugin):
    """
    A Maubot plugin. It registers the command

        !craiyon <text>

    and then forwards all images in a 3 by 3 image matrix to the user.
    Single images can be requested by reacting to the matrix with the
    matching keycap emoji (1️⃣ … 9️⃣).

    The mxc references are only stored at runtime and forgotten once the bot
    restarts.
    """

    # Maps the event ID of each sent image matrix to the metadata of its
    # individual images, in grid order. Created per instance in start():
    # a class-level dict would be shared by every instance of this plugin,
    # so each instance in a room would answer reactions on the others' grids.
    images: dict

    async def start(self) -> None:
        """ Create the per-instance grid cache before events are handled """
        self.images = {}
        await super().start()

    def emoji_to_number(self, emoji: str) -> int:
        """
        Convert a keycap reaction key to its zero-based grid index.

        :raises ValueError: if ``emoji`` is not one of ``EMOJI_NUMBERS``.
        """
        return EMOJI_NUMBERS.index(emoji)

    @command.passive(regex=EMOJI_REGEX,
                     field=lambda evt: evt.content.relates_to.key,
                     event_type=EventType.REACTION,
                     msgtypes=None)
    async def get_image(self, evt: ReactionEvent, _: Tuple[str]) -> None:
        """
        Reaction handler: when someone reacts to one of our image matrices
        with a keycap digit, send the corresponding single image.
        """
        relates_to = evt.content.relates_to
        grid = self.images.get(relates_to.event_id)
        if grid is None:
            # Not one of the matrices remembered by this instance
            return
        try:
            index = self.emoji_to_number(relates_to.key)
        except ValueError:
            # Key passed the regex but is not an exact keycap we offer
            return
        if index >= len(grid):
            # Craiyon returned fewer images than this digit refers to
            return
        image = grid[index]
        await self.client.send_image(evt.room_id, url=image.uri,
                                     file_name=image.name,
                                     info=ImageInfo(mimetype=image.mime,
                                                    size=image.size,
                                                    width=image.width,
                                                    height=image.height))

    @command.new()
    @command.argument("prompt", pass_raw=True, required=True)
    async def craiyon(self, evt: MessageEvent, prompt: str) -> None:
        """
        Forward ``prompt`` to Craiyon, post the returned images as a 3x3
        matrix and add one keycap reaction per image so users can pick one.
        Replies with an error message if Craiyon fails or returns no images.
        """
        await evt.react("🤖")
        # async with guarantees the connection is released on every path
        async with self.http.post(
            'https://backend.craiyon.com/generate',
            json={'prompt': f'{prompt}<br>'}
        ) as response:
            status = response.status
            payload = await response.json() if status < 400 else None

        if status >= 400:
            await evt.reply(f"Craiyon request failed (HTTP {status}).")
            return
        encoded = payload.get('images') if isinstance(payload, dict) else None
        if not encoded:
            await evt.reply("Craiyon did not return any images.")
            return
        # The 3x3 matrix and the reaction picker only have room for nine
        encoded = encoded[:len(EMOJI_NUMBERS)]

        pictures = [UploadableImage(f"{prompt}_{n}.jpg", b64decode(data))
                    for n, data in enumerate(encoded)]

        # Size the tiles from the images instead of assuming 256x256, so
        # larger images are not pasted over each other.
        tile_w = max(picture.pil.width for picture in pictures)
        tile_h = max(picture.pil.height for picture in pictures)
        canvas = Image.new(mode="RGB", size=(tile_w * 3, tile_h * 3))

        grid_meta = []
        for index, picture in enumerate(pictures):
            col, row = index % 3, index // 3
            canvas.paste(picture.pil, (col * tile_w, row * tile_h))
            await picture.upload(self.client)
            grid_meta.append(picture.meta())

        matrix = UploadableImage(f"{prompt}_3x3.png", pil=canvas)
        await matrix.upload(self.client)

        msg = await self.client.send_image(evt.room_id, url=matrix.uri,
                                           file_name=matrix.name,
                                           info=ImageInfo(
                                               mimetype=matrix.mime,
                                               size=matrix.size,
                                               width=matrix.pil.width,
                                               height=matrix.pil.height)
                                           )
        # Remember the matrix before adding reactions, so that reactions
        # arriving while ours are still being added can already be served.
        self.images[msg] = grid_meta
        # Only offer digits that correspond to an actual image
        for emoji in EMOJI_NUMBERS[:len(grid_meta)]:
            await evt.client.react(evt.room_id, msg, emoji)
|