This commit is contained in:
Skye 2023-03-17 20:22:07 +09:00
commit c7a0522ee1
Signed by: me
GPG key ID: 0104BC05F41B77B8
8 changed files with 2173 additions and 0 deletions

1
.envrc Normal file
View file

@ -0,0 +1 @@
use flake

2
.gitignore vendored Normal file
View file

@ -0,0 +1,2 @@
/target
/.direnv

1832
Cargo.lock generated Normal file

File diff suppressed because it is too large Load diff

14
Cargo.toml Normal file
View file

@ -0,0 +1,14 @@
[package]
name = "smolhaj-ng"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0.69"
reqwest = "0.11.14"
serde = { version = "1.0.156", features = ["derive"] }
serde_json = "1.0.94"
serenity = "0.11.5"
tokio = { version = "1.26.0", features = ["full"] }

77
flake.lock Normal file
View file

@ -0,0 +1,77 @@
{
"nodes": {
"naersk": {
"inputs": {
"nixpkgs": "nixpkgs"
},
"locked": {
"lastModified": 1671096816,
"narHash": "sha256-ezQCsNgmpUHdZANDCILm3RvtO1xH8uujk/+EqNvzIOg=",
"owner": "nix-community",
"repo": "naersk",
"rev": "d998160d6a076cfe8f9741e56aeec7e267e3e114",
"type": "github"
},
"original": {
"owner": "nix-community",
"ref": "master",
"repo": "naersk",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1678875422,
"narHash": "sha256-T3o6NcQPwXjxJMn2shz86Chch4ljXgZn746c2caGxd8=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "126f49a01de5b7e35a43fd43f891ecf6d3a51459",
"type": "github"
},
"original": {
"id": "nixpkgs",
"type": "indirect"
}
},
"nixpkgs_2": {
"locked": {
"lastModified": 1678875422,
"narHash": "sha256-T3o6NcQPwXjxJMn2shz86Chch4ljXgZn746c2caGxd8=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "126f49a01de5b7e35a43fd43f891ecf6d3a51459",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixpkgs-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"root": {
"inputs": {
"naersk": "naersk",
"nixpkgs": "nixpkgs_2",
"utils": "utils"
}
},
"utils": {
"locked": {
"lastModified": 1678901627,
"narHash": "sha256-U02riOqrKKzwjsxc/400XnElV+UtPUQWpANPlyazjH0=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "93a2b84fc4b70d9e089d029deacc3583435c2ed6",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
}
},
"root": "root",
"version": 7
}

21
flake.nix Normal file
View file

@ -0,0 +1,21 @@
{
inputs = {
naersk.url = "github:nix-community/naersk/master";
nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable";
utils.url = "github:numtide/flake-utils";
};
outputs = { self, nixpkgs, utils, naersk }:
utils.lib.eachDefaultSystem (system:
let
pkgs = import nixpkgs { inherit system; };
naersk-lib = pkgs.callPackage naersk { };
in
{
defaultPackage = naersk-lib.buildPackage ./.;
devShell = with pkgs; mkShell {
buildInputs = [ cargo rustc rustfmt pre-commit rustPackages.clippy pkg-config openssl ];
RUST_SRC_PATH = rustPlatform.rustLibSrc;
};
});
}

40
src/main.rs Normal file
View file

@ -0,0 +1,40 @@
mod openai;
use openai::{ChatCompletion, Configuration, Message, OpenAiApiClient, Role};
use std::env;
#[tokio::main]
async fn main() {
let api_key =
env::var("OPENAI_API_KEY").expect("Error: OPENAI_API_KEY environment variable not found");
// Replace with your organization ID (if applicable)
let config = Configuration {
api_key,
organization_id: None,
};
let client = OpenAiApiClient::new(config);
let messages = vec![Message {
role: Role::User,
content: String::from("Say this is a test!"),
}];
let chat_completion = ChatCompletion {
model: String::from("gpt-3.5-turbo"),
messages,
..Default::default()
};
match client.create_chat_completion(chat_completion).await {
Ok(response) => {
println!(
"Chat completion response: {:?}",
response.choices.first().unwrap().message.content
);
}
Err(error) => {
eprintln!("Error: {}", error);
}
}
}

186
src/openai.rs Normal file
View file

@ -0,0 +1,186 @@
// (fittingly) Written by GPT-4
use anyhow::{anyhow, Result};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
const BASE_URL: &str = "https://api.openai.com";
#[derive(Debug, Serialize, Deserialize)]
pub struct Configuration {
pub api_key: String,
pub organization_id: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
System,
User,
Assistant,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
pub role: Role,
pub content: String,
}
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct ChatCompletion {
pub model: String,
pub messages: Vec<Message>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub logit_bias: Option<HashMap<String, f64>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
pub created: i64,
pub model: String,
pub usage: ChatCompletionUsage,
pub choices: Vec<ChatCompletionChoice>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletionUsage {
pub prompt_tokens: i32,
pub completion_tokens: i32,
pub total_tokens: i32,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletionChoice {
pub message: Message,
pub finish_reason: String,
pub index: i32,
}
#[derive(Debug, Serialize)]
pub struct ModerationRequest {
pub input: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct ModerationResponse {
pub id: String,
pub model: String,
pub results: Vec<ModerationResult>,
}
#[derive(Debug, Deserialize)]
pub struct ModerationResult {
pub categories: HashMap<String, bool>,
pub category_scores: HashMap<String, f64>,
pub flagged: bool,
}
pub struct OpenAiApiClient {
configuration: Configuration,
client: Client,
}
impl OpenAiApiClient {
pub fn new(configuration: Configuration) -> Self {
OpenAiApiClient {
configuration,
client: Client::new(),
}
}
pub async fn create_chat_completion(
&self,
chat_completion: ChatCompletion,
) -> Result<ChatCompletionResponse> {
let url = format!("{}/v1/chat/completions", BASE_URL);
let response = self
.client
.post(&url)
.header("Content-Type", "application/json")
.header(
"Authorization",
format!("Bearer {}", self.configuration.api_key),
)
.header(
"OpenAI-Organization",
self.configuration
.organization_id
.as_ref()
.unwrap_or(&"".to_string()),
)
.json(&chat_completion)
.send()
.await?;
// Check for server errors
if !response.status().is_success() {
return Err(anyhow!(
"Server returned an error (status code: {}): {}",
response.status(),
response.text().await?
));
}
let chat_completion_response: ChatCompletionResponse = response.json().await?;
Ok(chat_completion_response)
}
pub async fn create_moderation(
&self,
moderation_request: ModerationRequest,
) -> Result<ModerationResponse> {
let url = format!("{}/v1/moderations", BASE_URL);
let response = self
.client
.post(&url)
.header("Content-Type", "application/json")
.header(
"Authorization",
format!("Bearer {}", self.configuration.api_key),
)
.header(
"OpenAI-Organization",
self.configuration
.organization_id
.as_ref()
.unwrap_or(&"".to_string()),
)
.json(&moderation_request)
.send()
.await?;
// Check for server errors
if !response.status().is_success() {
return Err(anyhow!(
"Server returned an error (status code: {}): {}",
response.status(),
response.text().await?
));
}
let moderation_response: ModerationResponse = response.json().await?;
Ok(moderation_response)
}
}