test
This commit is contained in:
commit
c7a0522ee1
8 changed files with 2173 additions and 0 deletions
1
.envrc
Normal file
1
.envrc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
use flake
|
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
/target
|
||||||
|
/.direnv
|
1832
Cargo.lock
generated
Normal file
1832
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
14
Cargo.toml
Normal file
14
Cargo.toml
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
[package]
|
||||||
|
name = "smolhaj-ng"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
anyhow = "1.0.69"
|
||||||
|
reqwest = "0.11.14"
|
||||||
|
serde = { version = "1.0.156", features = ["derive"] }
|
||||||
|
serde_json = "1.0.94"
|
||||||
|
serenity = "0.11.5"
|
||||||
|
tokio = { version = "1.26.0", features = ["full"] }
|
77
flake.lock
Normal file
77
flake.lock
Normal file
|
@ -0,0 +1,77 @@
|
||||||
|
{
|
||||||
|
"nodes": {
|
||||||
|
"naersk": {
|
||||||
|
"inputs": {
|
||||||
|
"nixpkgs": "nixpkgs"
|
||||||
|
},
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1671096816,
|
||||||
|
"narHash": "sha256-ezQCsNgmpUHdZANDCILm3RvtO1xH8uujk/+EqNvzIOg=",
|
||||||
|
"owner": "nix-community",
|
||||||
|
"repo": "naersk",
|
||||||
|
"rev": "d998160d6a076cfe8f9741e56aeec7e267e3e114",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "nix-community",
|
||||||
|
"ref": "master",
|
||||||
|
"repo": "naersk",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nixpkgs": {
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1678875422,
|
||||||
|
"narHash": "sha256-T3o6NcQPwXjxJMn2shz86Chch4ljXgZn746c2caGxd8=",
|
||||||
|
"owner": "NixOS",
|
||||||
|
"repo": "nixpkgs",
|
||||||
|
"rev": "126f49a01de5b7e35a43fd43f891ecf6d3a51459",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"id": "nixpkgs",
|
||||||
|
"type": "indirect"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nixpkgs_2": {
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1678875422,
|
||||||
|
"narHash": "sha256-T3o6NcQPwXjxJMn2shz86Chch4ljXgZn746c2caGxd8=",
|
||||||
|
"owner": "NixOS",
|
||||||
|
"repo": "nixpkgs",
|
||||||
|
"rev": "126f49a01de5b7e35a43fd43f891ecf6d3a51459",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "NixOS",
|
||||||
|
"ref": "nixpkgs-unstable",
|
||||||
|
"repo": "nixpkgs",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"root": {
|
||||||
|
"inputs": {
|
||||||
|
"naersk": "naersk",
|
||||||
|
"nixpkgs": "nixpkgs_2",
|
||||||
|
"utils": "utils"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"utils": {
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1678901627,
|
||||||
|
"narHash": "sha256-U02riOqrKKzwjsxc/400XnElV+UtPUQWpANPlyazjH0=",
|
||||||
|
"owner": "numtide",
|
||||||
|
"repo": "flake-utils",
|
||||||
|
"rev": "93a2b84fc4b70d9e089d029deacc3583435c2ed6",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "numtide",
|
||||||
|
"repo": "flake-utils",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"root": "root",
|
||||||
|
"version": 7
|
||||||
|
}
|
21
flake.nix
Normal file
21
flake.nix
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
{
|
||||||
|
inputs = {
|
||||||
|
naersk.url = "github:nix-community/naersk/master";
|
||||||
|
nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable";
|
||||||
|
utils.url = "github:numtide/flake-utils";
|
||||||
|
};
|
||||||
|
|
||||||
|
outputs = { self, nixpkgs, utils, naersk }:
|
||||||
|
utils.lib.eachDefaultSystem (system:
|
||||||
|
let
|
||||||
|
pkgs = import nixpkgs { inherit system; };
|
||||||
|
naersk-lib = pkgs.callPackage naersk { };
|
||||||
|
in
|
||||||
|
{
|
||||||
|
defaultPackage = naersk-lib.buildPackage ./.;
|
||||||
|
devShell = with pkgs; mkShell {
|
||||||
|
buildInputs = [ cargo rustc rustfmt pre-commit rustPackages.clippy pkg-config openssl ];
|
||||||
|
RUST_SRC_PATH = rustPlatform.rustLibSrc;
|
||||||
|
};
|
||||||
|
});
|
||||||
|
}
|
40
src/main.rs
Normal file
40
src/main.rs
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
mod openai;
|
||||||
|
use openai::{ChatCompletion, Configuration, Message, OpenAiApiClient, Role};
|
||||||
|
use std::env;
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
let api_key =
|
||||||
|
env::var("OPENAI_API_KEY").expect("Error: OPENAI_API_KEY environment variable not found");
|
||||||
|
|
||||||
|
// Replace with your organization ID (if applicable)
|
||||||
|
let config = Configuration {
|
||||||
|
api_key,
|
||||||
|
organization_id: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let client = OpenAiApiClient::new(config);
|
||||||
|
|
||||||
|
let messages = vec![Message {
|
||||||
|
role: Role::User,
|
||||||
|
content: String::from("Say this is a test!"),
|
||||||
|
}];
|
||||||
|
|
||||||
|
let chat_completion = ChatCompletion {
|
||||||
|
model: String::from("gpt-3.5-turbo"),
|
||||||
|
messages,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
match client.create_chat_completion(chat_completion).await {
|
||||||
|
Ok(response) => {
|
||||||
|
println!(
|
||||||
|
"Chat completion response: {:?}",
|
||||||
|
response.choices.first().unwrap().message.content
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(error) => {
|
||||||
|
eprintln!("Error: {}", error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
186
src/openai.rs
Normal file
186
src/openai.rs
Normal file
|
@ -0,0 +1,186 @@
|
||||||
|
// (fittingly) Written by GPT-4
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use reqwest::Client;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
const BASE_URL: &str = "https://api.openai.com";
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct Configuration {
|
||||||
|
pub api_key: String,
|
||||||
|
pub organization_id: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum Role {
|
||||||
|
System,
|
||||||
|
User,
|
||||||
|
Assistant,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct Message {
|
||||||
|
pub role: Role,
|
||||||
|
pub content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Default)]
|
||||||
|
pub struct ChatCompletion {
|
||||||
|
pub model: String,
|
||||||
|
pub messages: Vec<Message>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub temperature: Option<f64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_p: Option<f64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub n: Option<usize>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stream: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop: Option<Vec<String>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_tokens: Option<usize>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub presence_penalty: Option<f64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub frequency_penalty: Option<f64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub logit_bias: Option<HashMap<String, f64>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub user: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct ChatCompletionResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub created: i64,
|
||||||
|
pub model: String,
|
||||||
|
pub usage: ChatCompletionUsage,
|
||||||
|
pub choices: Vec<ChatCompletionChoice>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct ChatCompletionUsage {
|
||||||
|
pub prompt_tokens: i32,
|
||||||
|
pub completion_tokens: i32,
|
||||||
|
pub total_tokens: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct ChatCompletionChoice {
|
||||||
|
pub message: Message,
|
||||||
|
pub finish_reason: String,
|
||||||
|
pub index: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct ModerationRequest {
|
||||||
|
pub input: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub model: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct ModerationResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub model: String,
|
||||||
|
pub results: Vec<ModerationResult>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct ModerationResult {
|
||||||
|
pub categories: HashMap<String, bool>,
|
||||||
|
pub category_scores: HashMap<String, f64>,
|
||||||
|
pub flagged: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct OpenAiApiClient {
|
||||||
|
configuration: Configuration,
|
||||||
|
client: Client,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OpenAiApiClient {
|
||||||
|
pub fn new(configuration: Configuration) -> Self {
|
||||||
|
OpenAiApiClient {
|
||||||
|
configuration,
|
||||||
|
client: Client::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn create_chat_completion(
|
||||||
|
&self,
|
||||||
|
chat_completion: ChatCompletion,
|
||||||
|
) -> Result<ChatCompletionResponse> {
|
||||||
|
let url = format!("{}/v1/chat/completions", BASE_URL);
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post(&url)
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header(
|
||||||
|
"Authorization",
|
||||||
|
format!("Bearer {}", self.configuration.api_key),
|
||||||
|
)
|
||||||
|
.header(
|
||||||
|
"OpenAI-Organization",
|
||||||
|
self.configuration
|
||||||
|
.organization_id
|
||||||
|
.as_ref()
|
||||||
|
.unwrap_or(&"".to_string()),
|
||||||
|
)
|
||||||
|
.json(&chat_completion)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Check for server errors
|
||||||
|
if !response.status().is_success() {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"Server returned an error (status code: {}): {}",
|
||||||
|
response.status(),
|
||||||
|
response.text().await?
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let chat_completion_response: ChatCompletionResponse = response.json().await?;
|
||||||
|
Ok(chat_completion_response)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn create_moderation(
|
||||||
|
&self,
|
||||||
|
moderation_request: ModerationRequest,
|
||||||
|
) -> Result<ModerationResponse> {
|
||||||
|
let url = format!("{}/v1/moderations", BASE_URL);
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post(&url)
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header(
|
||||||
|
"Authorization",
|
||||||
|
format!("Bearer {}", self.configuration.api_key),
|
||||||
|
)
|
||||||
|
.header(
|
||||||
|
"OpenAI-Organization",
|
||||||
|
self.configuration
|
||||||
|
.organization_id
|
||||||
|
.as_ref()
|
||||||
|
.unwrap_or(&"".to_string()),
|
||||||
|
)
|
||||||
|
.json(&moderation_request)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Check for server errors
|
||||||
|
if !response.status().is_success() {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"Server returned an error (status code: {}): {}",
|
||||||
|
response.status(),
|
||||||
|
response.text().await?
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let moderation_response: ModerationResponse = response.json().await?;
|
||||||
|
Ok(moderation_response)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue