/*
 * geminibot/geminibot.ts
 * 2024-07-15 19:16:36 +02:00
 *
 * 187 lines
 * 10 KiB
 * TypeScript
 */

import { createRestAPIClient, createStreamingAPIClient } from 'masto';
import { GoogleGenerativeAI } from '@google/generative-ai';
import { Client } from '@elastic/elasticsearch';
import Jimp from "jimp";
import * as config from './config.js';
/** Escapes RegExp metacharacters so `text` is matched literally inside a pattern. */
function escapeRegExp(text: string): string {
  return text.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')
}

/**
 * MIME type announced to Gemini for a Mastodon attachment, inferred from the
 * attachment `type` and the URL extension. `gifv` attachments are MP4 clips,
 * so they map to `video/mp4` instead of `application/octet-stream`.
 */
function mimeTypeFor(type: string, url: string): string {
  const lower = url.toLowerCase()
  switch (type) {
    case 'image':
      if (lower.endsWith('.png')) return 'image/png'
      if (lower.endsWith('.gif')) return 'image/gif'
      return 'image/jpeg'
    case 'video':
    case 'gifv':
      return 'video/mp4'
    case 'audio':
      return 'audio/mpeg'
    default:
      return 'application/octet-stream'
  }
}

/**
 * Splits `text` into consecutive chunks of at most `size` code points.
 * Slicing by code point (not UTF-16 unit) keeps emoji and other astral
 * characters from being cut in half at a chunk boundary.
 */
function chunkText(text: string, size = 500): string[] {
  const codePoints = Array.from(text)
  const chunks: string[] = []
  for (let i = 0; i < codePoints.length; i += size) {
    chunks.push(codePoints.slice(i, i + size).join(''))
  }
  return chunks
}

/**
 * Sort comparator placing 'user' turns before 'model' turns. Equal roles
 * compare as 0, so the comparator is consistent and the stable sort keeps the
 * stored order within each role.
 */
function byRoleUserFirst(a: { role: string }, b: { role: string }): number {
  if (a.role === b.role) return 0
  return a.role < b.role ? 1 : -1
}

/** True for chat parts worth persisting: text, or PNG/JPEG inline images. */
function isPersistable(part: any): boolean {
  return Boolean(part.text || part.inlineData?.mimeType === 'image/png' || part.inlineData?.mimeType === 'image/jpeg')
}

/**
 * Bot entry point. Listens on the Mastodon user notification stream and
 * answers every non-direct mention (not from the bot itself) with a Gemini
 * reply. Per-user chat history is read from and appended to Elasticsearch.
 * A reply mention matching "@<nick> ... hilo..." also feeds the whole
 * Mastodon thread to the model. Setup failures, or any error escaping the
 * loop, are logged and end the process with exit code 1.
 */
(async () => {
  try {
    // config.index names the Elasticsearch index and config.nick the bot's own
    // account; these were previously assigned crosswise.
    const index = config.index,
      nick = config.nick,
      rest = createRestAPIClient({
        url: `https://${config.url}`,
        accessToken: config.accessToken,
        requestInit: {
          headers: [
            ['User-Agent', nick]
          ]
        }
      }),
      ws = createStreamingAPIClient({
        streamingApiUrl: `wss://${config.url}`,
        accessToken: config.accessToken,
        retry: true
      }),
      genAI = new GoogleGenerativeAI(config.apiKey),
      model = genAI.getGenerativeModel({ model: config.model }),
      client = new Client({ node: config.elasticNode }),
      initialize = config.initialize,
      // Built once; nick is escaped so characters such as '.' match literally.
      threadRequest = new RegExp('^@' + escapeRegExp(nick) + ' .* hilo.*$')

    /** Posts a public reply to `status`, threaded under `inReplyToId` (defaults to `status` itself). */
    const reply = (status: any, text: string, inReplyToId: string = status.id) =>
      rest.v1.statuses.create({
        status: text,
        sensitive: false,
        visibility: 'public',
        inReplyToId,
        language: status.language
      })

    /**
     * Downloads the media attachments of `status` and returns them as Gemini
     * inline-data parts (base64). PNG/JPEG images wider than 512px are
     * downscaled to 512px wide and labelled with the format they were
     * re-encoded to. Attachments that cannot be fetched or decoded are logged
     * and skipped, so they do not abort the whole answer.
     */
    const getMediaAttachments = async (status: any) => {
      const parts = await Promise.all(status.mediaAttachments.map(async (attachment: { url: string | null; type: string }) => {
        if (!attachment.url) return undefined
        try {
          const response = await fetch(attachment.url)
          if (!response.ok) return undefined
          const url = attachment.url.toLowerCase()
          let data: Buffer = Buffer.from(await response.arrayBuffer())
          let mimeType = mimeTypeFor(attachment.type, url)
          if (attachment.type === 'image' && (url.endsWith('.png') || /\.jpe?g$/.test(url))) {
            const image = await Jimp.read(data)
            if (image.getWidth() > 512) {
              mimeType = image.getMIME() === 'image/jpeg' ? Jimp.MIME_JPEG : Jimp.MIME_PNG
              data = await image.resize(512, Jimp.AUTO).getBufferAsync(mimeType)
            }
          }
          return { inlineData: { data: data.toString('base64'), mimeType } }
        } catch (e) {
          console.error(`Skipping attachment ${attachment.url}:`, e)
          return undefined
        }
      }))
      return parts.filter(part => part?.inlineData?.data)
    }

    /**
     * Rebuilds the thread containing `status` as chat turns in chronological
     * order (root first), excluding `status` itself. Posts by the bot get role
     * 'model', everything else 'user'.
     */
    const getThreadEntries = async (status: any) => {
      let root = await rest.v1.statuses.$select(status.id).fetch()
      while (root.inReplyToId) {
        root = await rest.v1.statuses.$select(root.inReplyToId).fetch()
      }
      const { descendants } = await rest.v1.statuses.$select(root.id).context.fetch()
      const thread = [root, ...descendants].filter(s => s.id !== status.id)
      // Promise.all preserves input order, unlike pushing from async callbacks
      // (which also raced on a shared `parts` variable).
      return Promise.all(thread.map(async s => {
        const media = s.mediaAttachments.length > 0 ? await getMediaAttachments(s) : []
        return {
          role: s.account.acct === nick ? 'model' : 'user',
          parts: [{ text: `${s.account.acct}: ${s.content}` }, ...media]
        }
      }))
    }

    /**
     * Answers a single mention. Loads the author's stored history and, when
     * requested, the surrounding thread. Asks Gemini, posts the answer as a
     * chain of replies of at most 500 characters each, and stores the exchange
     * in Elasticsearch. Errors propagate to the caller.
     */
    const handleMention = async (status: any) => {
      const query = await client.search({
        index,
        size: 100,
        sort: [{ timestamp: { order: 'desc' } }],
        query: {
          term: {
            user: status.account.acct
          }
        }
      })
      const hits: any[] = query.hits?.hits ?? []
      // The search yields the newest 100 exchanges newest-first; reverse them so
      // the chat history reads chronologically and trimming drops the oldest.
      const history: any[] = [...hits].reverse().flatMap(hit => {
        const entry = hit._source?.entry
        return entry && entry.length > 0 ? [...entry].sort(byRoleUserFirst) : []
      })
      const content = ('' + status.content).replace(/<[^>]*>?/gm, '').trim()
      const entries: any[] = status.inReplyToId && threadRequest.test(content)
        ? await getThreadEntries(status)
        : []
      const question = `${status.account.acct}: ${status.content}`
      const parts: any[] = status.mediaAttachments.length > 0 ? await getMediaAttachments(status) : []

      // Drop the oldest context until the serialized history fits config.maxSize;
      // the emptiness guard keeps a tiny maxSize from looping forever.
      const historySize = () => new Blob([JSON.stringify([...history, ...entries])]).size
      while (historySize() >= config.maxSize && (history.length > 0 || entries.length > 0)) {
        if (history.length > 0) {
          history.shift()
        } else {
          entries.shift()
        }
      }
      // The Gemini SDK rejects a chat history whose first turn is not 'user'.
      while (history.length > 0 && history[0].role !== 'user') {
        history.shift()
      }

      const chat = model.startChat({
        history: history.length > 0
          ? [...history, ...entries]
          // No stored history: prime with the configured prompt, keeping any
          // thread context instead of discarding it.
          : [{ role: 'user', parts: [{ text: initialize }] }, ...entries],
        generationConfig: { maxOutputTokens: 220, temperature: 0.7, topP: 0.95, topK: 40 }
      })
      const result = await chat.sendMessage([{ text: question }, ...parts])
      const answer = result.response.text()
      if (!answer) {
        await reply(status, 'No response.')
        return
      }

      // Each chunk replies to the previous one, forming a readable chain.
      let inReplyToId = status.id
      for (const chunk of chunkText(answer, 500)) {
        const posted = await reply(status, chunk, inReplyToId)
        inReplyToId = posted?.id || status.id
      }

      const user = { role: 'user', parts: [{ text: question }, ...parts.filter(isPersistable)] }
      const bot = { role: 'model', parts: [{ text: answer }] }
      await client.index({
        index,
        document: {
          user: status.account.acct,
          entry: [
            user,
            ...entries.map(entry => ({ role: entry.role, parts: entry.parts.filter(isPersistable) })),
            bot
          ],
          timestamp: new Date()
        }
      })
    }

    for await (const notification of ws.user.notification.subscribe()) {
      const payload: any = notification.payload
      const status = payload['status']
      if (payload['type'] !== 'mention' || !status) continue
      if (status.visibility === 'direct' || status.account.acct === nick || status.content.trim() === '') continue
      try {
        await handleMention(status)
      } catch (e) {
        console.error(e)
        // Tell the user what failed, truncated to a single post. A failure while
        // reporting is only logged so it cannot tear down the stream loop.
        const message = (e instanceof Error && e.message) || String(e)
        try {
          await reply(status, chunkText(message)[0])
        } catch (replyError) {
          console.error(replyError)
        }
      }
    }
  } catch (e) {
    console.error(e)
    process.exit(1)
  }
})()