//! Minimal async client for the OpenAI chat-completions and moderations APIs.
// (fittingly) Written by GPT-4
|
|
use anyhow::{anyhow, Result};
|
|
use reqwest::Client;
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::HashMap;
|
|
|
|
/// Root URL for all OpenAI API endpoints; versioned request paths are appended to it.
const BASE_URL: &str = "https://api.openai.com";
|
|
|
|
/// Credentials used to authenticate requests to the OpenAI API.
#[derive(Debug, Serialize, Deserialize)]
pub struct Configuration {
    /// Secret API key, sent as the `Authorization: Bearer …` header.
    pub api_key: String,
    /// Optional organization id, sent as the `OpenAI-Organization` header.
    pub organization_id: Option<String>,
}
|
|
|
|
/// Author of a chat message; serialized in lowercase
/// (`"system"`, `"user"`, `"assistant"`) to match the API's wire format.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}
|
|
|
|
/// A single chat message, used both in requests (`ChatCompletion::messages`)
/// and in responses (`ChatCompletionChoice::message`).
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// Who authored the message.
    pub role: Role,
    /// The message text.
    pub content: String,
}
|
|
|
|
/// Request payload for `POST /v1/chat/completions`.
///
/// Only `model` and `messages` are always serialized; every `Option` field is
/// omitted from the JSON body when `None` (`skip_serializing_if`), so
/// `ChatCompletion::default()` plus a model and messages produces a minimal
/// request.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct ChatCompletion {
    /// Model identifier, e.g. a chat-capable model name.
    pub model: String,
    /// Conversation history, in order.
    pub messages: Vec<Message>,
    /// Sampling temperature. NOTE(review): valid range not enforced here —
    /// presumably validated server-side.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// Nucleus-sampling probability mass.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    /// Number of completions to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<usize>,
    /// Whether to stream the response. NOTE(review): this client reads the
    /// body as a single JSON document, so streaming responses would fail to
    /// deserialize — confirm before setting `Some(true)`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Sequences at which generation stops.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    /// Maximum number of tokens to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<usize>,
    /// Presence penalty.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,
    /// Frequency penalty.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,
    /// Per-token logit biases, keyed by token id (as a string).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, f64>>,
    /// End-user identifier forwarded to the API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}
|
|
|
|
/// Successful response body from `POST /v1/chat/completions`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
    /// Server-assigned id of this completion.
    pub id: String,
    /// Object type discriminator returned by the API.
    pub object: String,
    /// Creation time as a Unix timestamp.
    pub created: i64,
    /// Model that produced the completion.
    pub model: String,
    /// Token accounting for this request.
    pub usage: ChatCompletionUsage,
    /// One entry per requested completion (`ChatCompletion::n`).
    pub choices: Vec<ChatCompletionChoice>,
}
|
|
|
|
/// Token counts reported by the API for a single completion request.
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletionUsage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: i32,
    /// Tokens generated in the completion(s).
    pub completion_tokens: i32,
    /// `prompt_tokens + completion_tokens`, as reported by the server.
    pub total_tokens: i32,
}
|
|
|
|
/// One generated completion within a `ChatCompletionResponse`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletionChoice {
    /// The generated assistant message.
    pub message: Message,
    /// Why generation stopped (e.g. natural stop vs. length limit —
    /// exact values are defined by the API, not this client).
    pub finish_reason: String,
    /// Zero-based index of this choice among the returned choices.
    pub index: i32,
}
|
|
|
|
/// Request payload for `POST /v1/moderations`.
#[derive(Debug, Serialize)]
pub struct ModerationRequest {
    /// Text to classify.
    pub input: String,
    /// Optional moderation model; omitted from the JSON when `None`,
    /// letting the server pick its default.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
}
|
|
|
|
/// Successful response body from `POST /v1/moderations`.
#[derive(Debug, Deserialize)]
pub struct ModerationResponse {
    /// Server-assigned id of this moderation run.
    pub id: String,
    /// Model that produced the classification.
    pub model: String,
    /// One result per input (this client sends a single input string).
    pub results: Vec<ModerationResult>,
}
|
|
|
|
/// Classification outcome for one moderated input.
///
/// Category names are kept as free-form map keys rather than an enum, so new
/// server-side categories deserialize without a client update.
#[derive(Debug, Deserialize)]
pub struct ModerationResult {
    /// Per-category boolean verdicts.
    pub categories: HashMap<String, bool>,
    /// Per-category confidence scores.
    pub category_scores: HashMap<String, f64>,
    /// True when any category was flagged by the server.
    pub flagged: bool,
}
|
|
|
|
/// Async OpenAI API client: credentials plus a pooled `reqwest::Client`.
pub struct OpenAiApiClient {
    /// API key and optional organization id used on every request.
    configuration: Configuration,
    /// Shared HTTP client; reusing it lets reqwest pool connections.
    client: Client,
}
|
|
|
|
impl OpenAiApiClient {
|
|
pub fn new(configuration: Configuration) -> Self {
|
|
OpenAiApiClient {
|
|
configuration,
|
|
client: Client::new(),
|
|
}
|
|
}
|
|
|
|
pub async fn create_chat_completion(
|
|
&self,
|
|
chat_completion: ChatCompletion,
|
|
) -> Result<ChatCompletionResponse> {
|
|
let url = format!("{}/v1/chat/completions", BASE_URL);
|
|
let response = self
|
|
.client
|
|
.post(&url)
|
|
.header("Content-Type", "application/json")
|
|
.header(
|
|
"Authorization",
|
|
format!("Bearer {}", self.configuration.api_key),
|
|
)
|
|
.header(
|
|
"OpenAI-Organization",
|
|
self.configuration
|
|
.organization_id
|
|
.as_ref()
|
|
.unwrap_or(&"".to_string()),
|
|
)
|
|
.json(&chat_completion)
|
|
.send()
|
|
.await?;
|
|
|
|
// Check for server errors
|
|
if !response.status().is_success() {
|
|
return Err(anyhow!(
|
|
"Server returned an error (status code: {}): {}",
|
|
response.status(),
|
|
response.text().await?
|
|
));
|
|
}
|
|
|
|
let chat_completion_response: ChatCompletionResponse = response.json().await?;
|
|
Ok(chat_completion_response)
|
|
}
|
|
|
|
pub async fn create_moderation(
|
|
&self,
|
|
moderation_request: ModerationRequest,
|
|
) -> Result<ModerationResponse> {
|
|
let url = format!("{}/v1/moderations", BASE_URL);
|
|
let response = self
|
|
.client
|
|
.post(&url)
|
|
.header("Content-Type", "application/json")
|
|
.header(
|
|
"Authorization",
|
|
format!("Bearer {}", self.configuration.api_key),
|
|
)
|
|
.header(
|
|
"OpenAI-Organization",
|
|
self.configuration
|
|
.organization_id
|
|
.as_ref()
|
|
.unwrap_or(&"".to_string()),
|
|
)
|
|
.json(&moderation_request)
|
|
.send()
|
|
.await?;
|
|
|
|
// Check for server errors
|
|
if !response.status().is_success() {
|
|
return Err(anyhow!(
|
|
"Server returned an error (status code: {}): {}",
|
|
response.status(),
|
|
response.text().await?
|
|
));
|
|
}
|
|
|
|
let moderation_response: ModerationResponse = response.json().await?;
|
|
Ok(moderation_response)
|
|
}
|
|
}
|