parent
c7a0522ee1
commit
1bee375d29
@ -1,40 +1,223 @@
|
||||
mod openai;
|
||||
use openai::{ChatCompletion, Configuration, Message, OpenAiApiClient, Role};
|
||||
use std::env;
|
||||
use futures::StreamExt;
|
||||
use serenity::{
|
||||
async_trait,
|
||||
model::{channel::Message, gateway::Ready, prelude::*},
|
||||
prelude::*,
|
||||
utils::Color,
|
||||
};
|
||||
use std::{env, time::Duration};
|
||||
|
||||
struct Handler {
|
||||
channel: ChannelId,
|
||||
client: openai::OpenAiApiClient,
|
||||
reset_time: tokio::sync::RwLock<time::OffsetDateTime>,
|
||||
model: String,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl EventHandler for Handler {
|
||||
async fn message(&self, ctx: Context, msg: Message) {
|
||||
if msg.content == "\\reset" {
|
||||
let mut reset_time = self.reset_time.write().await;
|
||||
*reset_time = time::OffsetDateTime::now_utc();
|
||||
}
|
||||
|
||||
if msg.channel_id != self.channel
|
||||
|| msg.webhook_id.is_some()
|
||||
|| msg.author.bot
|
||||
|| msg.content.starts_with('\\')
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if msg.content == "die" {
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let channel = msg.channel(&ctx).await.unwrap().guild().unwrap();
|
||||
|
||||
let typing = channel.clone().start_typing(&ctx.http).unwrap();
|
||||
|
||||
let after = (((self
|
||||
.reset_time
|
||||
.read()
|
||||
.await
|
||||
.max(time::OffsetDateTime::now_utc() - Duration::from_secs(5 * 60))
|
||||
.unix_timestamp()
|
||||
* 1000)
|
||||
- 1420070400000)
|
||||
<< 22) as u64;
|
||||
|
||||
let mut msgs = channel
|
||||
.messages(&ctx, |messages| messages.after(after))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(250)).await;
|
||||
|
||||
let current_id = ctx.http.get_current_user().await.unwrap().id;
|
||||
|
||||
msgs.reverse();
|
||||
|
||||
for pk_msg in msgs.clone().into_iter().filter(|message| message.webhook_id.is_some()) {
|
||||
if let Some(id) = msgs.iter().rposition(|message| message.id < pk_msg.id && message.content.contains(&pk_msg.content)) {
|
||||
msgs.remove(id);
|
||||
}
|
||||
}
|
||||
|
||||
let ref_ctx = &ctx;
|
||||
|
||||
let msgs = msgs.into_iter().filter(|message| {
|
||||
#[allow(clippy::if_same_then_else)]
|
||||
#[allow(clippy::needless_bool)]
|
||||
if message.webhook_id.is_some() && !message.content.trim().starts_with('\\') {
|
||||
true
|
||||
} else if message.content.starts_with('\\') {
|
||||
false
|
||||
} else if message.author.bot && message.author.id != current_id {
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
});
|
||||
|
||||
let last_msg = msgs.clone().last().unwrap();
|
||||
|
||||
let msgs = tokio_stream::iter(msgs)
|
||||
.map(|message| async move {
|
||||
if message.author.id == current_id {
|
||||
openai::Message {
|
||||
content: message.content.clone(),
|
||||
role: openai::Role::Assistant,
|
||||
}
|
||||
} else {
|
||||
openai::Message {
|
||||
content: format!(
|
||||
"{}: {}",
|
||||
message
|
||||
.author_nick(ref_ctx)
|
||||
.await
|
||||
.unwrap_or(message.author.name.clone()),
|
||||
message.content
|
||||
),
|
||||
role: openai::Role::User,
|
||||
}
|
||||
}
|
||||
})
|
||||
.buffered(100)
|
||||
.collect::<Vec<_>>()
|
||||
.await;
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.create_chat_completion(openai::ChatCompletion {
|
||||
model: self.model.clone(),
|
||||
messages: msgs,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let response = response.choices[0].message.content.as_str();
|
||||
|
||||
let is_appropriate = !self
|
||||
.client
|
||||
.create_moderation(openai::ModerationRequest {
|
||||
input: response.to_owned(),
|
||||
model: None,
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
.results[0]
|
||||
.flagged;
|
||||
|
||||
typing.stop().unwrap();
|
||||
if is_appropriate {
|
||||
last_msg.reply(&ctx, response).await.unwrap();
|
||||
} else {
|
||||
msg.channel(&ctx)
|
||||
.await
|
||||
.unwrap()
|
||||
.guild()
|
||||
.unwrap()
|
||||
.send_message(&ctx, |message| {
|
||||
message.embed(|embed| {
|
||||
embed
|
||||
.title("Response flagged!")
|
||||
.description("The generated response may have been inappropriate.")
|
||||
.color(Color::RED)
|
||||
})
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
async fn ready(&self, ctx: Context, ready: Ready) {
|
||||
println!("{} is connected!", ready.user.name);
|
||||
self.channel.say(&ctx, "\\Smolhaj reset").await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let api_key =
|
||||
env::var("OPENAI_API_KEY").expect("Error: OPENAI_API_KEY environment variable not found");
|
||||
|
||||
let token = env::var("TOKEN").expect("Error: TOKEN environment variable not found");
|
||||
|
||||
let channel = env::var("CHANNEL")
|
||||
.expect("Error: CHANNEL environment variable not found")
|
||||
.parse()
|
||||
.expect("Expected channel ID to be an integer");
|
||||
|
||||
let model = env::var("MODEL").expect("Error: MODEL environment variable not found");
|
||||
|
||||
// Replace with your organization ID (if applicable)
|
||||
let config = Configuration {
|
||||
let config = openai::Configuration {
|
||||
api_key,
|
||||
organization_id: None,
|
||||
};
|
||||
|
||||
let client = OpenAiApiClient::new(config);
|
||||
let openai_client = openai::OpenAiApiClient::new(config);
|
||||
|
||||
let messages = vec![Message {
|
||||
role: Role::User,
|
||||
content: String::from("Say this is a test!"),
|
||||
}];
|
||||
let intents = GatewayIntents::GUILD_MESSAGES | GatewayIntents::MESSAGE_CONTENT;
|
||||
|
||||
let chat_completion = ChatCompletion {
|
||||
model: String::from("gpt-3.5-turbo"),
|
||||
messages,
|
||||
..Default::default()
|
||||
};
|
||||
let mut client = Client::builder(&token, intents)
|
||||
.event_handler(Handler {
|
||||
channel: ChannelId(channel),
|
||||
client: openai_client,
|
||||
reset_time: time::OffsetDateTime::now_utc().into(),
|
||||
model
|
||||
})
|
||||
.await
|
||||
.expect("Err creating client");
|
||||
|
||||
match client.create_chat_completion(chat_completion).await {
|
||||
Ok(response) => {
|
||||
println!(
|
||||
"Chat completion response: {:?}",
|
||||
response.choices.first().unwrap().message.content
|
||||
);
|
||||
}
|
||||
Err(error) => {
|
||||
eprintln!("Error: {}", error);
|
||||
}
|
||||
if let Err(why) = client.start().await {
|
||||
println!("Client error: {:?}", why);
|
||||
}
|
||||
|
||||
// let messages = vec![openai::Message {
|
||||
// role: openai::Role::User,
|
||||
// content: String::from("Say this is a test!"),
|
||||
// }];
|
||||
|
||||
// let chat_completion = openai::ChatCompletion {
|
||||
// model: String::from("gpt-3.5-turbo"),
|
||||
// messages,
|
||||
// ..Default::default()
|
||||
// };
|
||||
|
||||
// match client.create_chat_completion(chat_completion).await {
|
||||
// Ok(response) => {
|
||||
// println!(
|
||||
// "Chat completion response: {:?}",
|
||||
// response.choices.first().unwrap().message.content
|
||||
// );
|
||||
// }
|
||||
// Err(error) => {
|
||||
// eprintln!("Error: {}", error);
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
Loading…
Reference in new issue