// (fittingly) Written by GPT-4 use anyhow::{anyhow, Result}; use reqwest::Client; use serde::{Deserialize, Serialize}; use std::collections::HashMap; const BASE_URL: &str = "https://api.openai.com"; #[derive(Debug, Serialize, Deserialize)] pub struct Configuration { pub api_key: String, pub organization_id: Option, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum Role { System, User, Assistant, } #[derive(Debug, Serialize, Deserialize)] pub struct Message { pub role: Role, pub content: String, } #[derive(Debug, Serialize, Deserialize, Default)] pub struct ChatCompletion { pub model: String, pub messages: Vec, #[serde(skip_serializing_if = "Option::is_none")] pub temperature: Option, #[serde(skip_serializing_if = "Option::is_none")] pub top_p: Option, #[serde(skip_serializing_if = "Option::is_none")] pub n: Option, #[serde(skip_serializing_if = "Option::is_none")] pub stream: Option, #[serde(skip_serializing_if = "Option::is_none")] pub stop: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub max_tokens: Option, #[serde(skip_serializing_if = "Option::is_none")] pub presence_penalty: Option, #[serde(skip_serializing_if = "Option::is_none")] pub frequency_penalty: Option, #[serde(skip_serializing_if = "Option::is_none")] pub logit_bias: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub user: Option, } #[derive(Debug, Serialize, Deserialize)] pub struct ChatCompletionResponse { pub id: String, pub object: String, pub created: i64, pub model: String, pub usage: ChatCompletionUsage, pub choices: Vec, } #[derive(Debug, Serialize, Deserialize)] pub struct ChatCompletionUsage { pub prompt_tokens: i32, pub completion_tokens: i32, pub total_tokens: i32, } #[derive(Debug, Serialize, Deserialize)] pub struct ChatCompletionChoice { pub message: Message, pub finish_reason: String, pub index: i32, } #[derive(Debug, Serialize)] pub struct ModerationRequest { pub input: String, #[serde(skip_serializing_if = "Option::is_none")] pub model: Option, } #[derive(Debug, Deserialize)] pub struct ModerationResponse { pub id: String, pub model: String, pub results: Vec, } #[derive(Debug, Deserialize)] pub struct ModerationResult { pub categories: HashMap, pub category_scores: HashMap, pub flagged: bool, } pub struct OpenAiApiClient { configuration: Configuration, client: Client, } impl OpenAiApiClient { pub fn new(configuration: Configuration) -> Self { OpenAiApiClient { configuration, client: Client::new(), } } pub async fn create_chat_completion( &self, chat_completion: ChatCompletion, ) -> Result { let url = format!("{}/v1/chat/completions", BASE_URL); let response = self .client .post(&url) .header("Content-Type", "application/json") .header( "Authorization", format!("Bearer {}", self.configuration.api_key), ) .header( "OpenAI-Organization", self.configuration .organization_id .as_ref() .unwrap_or(&"".to_string()), ) .json(&chat_completion) .send() .await?; // Check for server errors if !response.status().is_success() { return Err(anyhow!( "Server returned an error (status code: {}): {}", response.status(), response.text().await? )); } let chat_completion_response: ChatCompletionResponse = response.json().await?; Ok(chat_completion_response) } pub async fn create_moderation( &self, moderation_request: ModerationRequest, ) -> Result { let url = format!("{}/v1/moderations", BASE_URL); let response = self .client .post(&url) .header("Content-Type", "application/json") .header( "Authorization", format!("Bearer {}", self.configuration.api_key), ) .header( "OpenAI-Organization", self.configuration .organization_id .as_ref() .unwrap_or(&"".to_string()), ) .json(&moderation_request) .send() .await?; // Check for server errors if !response.status().is_success() { return Err(anyhow!( "Server returned an error (status code: {}): {}", response.status(), response.text().await? )); } let moderation_response: ModerationResponse = response.json().await?; Ok(moderation_response) } }