smolhaj-ng/src/openai.rs
2023-03-17 20:22:07 +09:00

186 lines
5.1 KiB
Rust

// (fittingly) Written by GPT-4
use anyhow::{anyhow, Result};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Root of the OpenAI REST API; endpoint paths such as `/v1/chat/completions` and
/// `/v1/moderations` are appended to it when building request URLs.
const BASE_URL: &str = "https://api.openai.com";
/// Credentials used to authenticate requests against the OpenAI API.
///
/// `Debug` is implemented by hand instead of derived so that formatting a
/// `Configuration` (for example in a log line or a panic message) never prints the
/// secret API key.
#[derive(Serialize, Deserialize)]
pub struct Configuration {
    /// Secret API key, sent as a bearer token with every request.
    pub api_key: String,
    /// Optional organization ID, sent in the `OpenAI-Organization` request header.
    pub organization_id: Option<String>,
}

impl std::fmt::Debug for Configuration {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Configuration")
            // Never expose the key itself; only show that the field exists.
            .field("api_key", &"<redacted>")
            .field("organization_id", &self.organization_id)
            .finish()
    }
}
/// Author of a chat [`Message`].
///
/// Every variant (de)serializes as its lowercase name.
#[derive(Debug, Serialize, Deserialize)]
pub enum Role {
    /// Wire value `"system"`.
    #[serde(rename = "system")]
    System,
    /// Wire value `"user"`.
    #[serde(rename = "user")]
    User,
    /// Wire value `"assistant"`.
    #[serde(rename = "assistant")]
    Assistant,
}
/// One entry of a chat conversation: who wrote it and its text.
///
/// Used both in outgoing [`ChatCompletion`] requests and in the choices returned by the
/// server.
#[derive(Deserialize, Serialize, Debug)]
pub struct Message {
    /// Author of this message.
    pub role: Role,
    /// Text body of this message.
    pub content: String,
}
/// Request body for `POST /v1/chat/completions`.
///
/// `model` and `messages` are always serialized. Every `Option` field is omitted from the
/// JSON body while it is `None`, so only explicitly chosen parameters reach the server.
/// Thanks to `Default`, requests can be built with struct-update syntax:
/// `ChatCompletion { model, messages, ..Default::default() }`.
///
/// NOTE(review): `create_chat_completion` parses the reply as a single JSON document, so
/// setting `stream` to `Some(true)` is presumably unsupported by this client — confirm.
#[derive(Default, Deserialize, Serialize, Debug)]
pub struct ChatCompletion {
    /// Name of the model to query.
    pub model: String,
    /// Conversation sent to the model, in order.
    pub messages: Vec<Message>,
    /// Optional `temperature` request parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// Optional `top_p` request parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    /// Optional `n` request parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<usize>,
    /// Optional `stream` request parameter (see the note above).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Optional `stop` request parameter, sent as a JSON array of strings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    /// Optional `max_tokens` request parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<usize>,
    /// Optional `presence_penalty` request parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,
    /// Optional `frequency_penalty` request parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,
    /// Optional `logit_bias` request parameter, sent as a JSON object.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, f64>>,
    /// Optional `user` request parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}
/// Body of a successful (2xx) reply to `POST /v1/chat/completions`.
#[derive(Deserialize, Serialize, Debug)]
pub struct ChatCompletionResponse {
    /// Identifier assigned by the server.
    pub id: String,
    /// Object type string reported by the server.
    pub object: String,
    /// Creation time reported by the server (presumably Unix seconds — confirm).
    pub created: i64,
    /// Model name reported by the server.
    pub model: String,
    /// Token accounting for this request.
    pub usage: ChatCompletionUsage,
    /// Generated alternatives returned by the server.
    pub choices: Vec<ChatCompletionChoice>,
}
/// Token counts carried in [`ChatCompletionResponse::usage`].
#[derive(Deserialize, Serialize, Debug)]
pub struct ChatCompletionUsage {
    /// Tokens attributed to the prompt, as reported by the server.
    pub prompt_tokens: i32,
    /// Tokens attributed to the completion, as reported by the server.
    pub completion_tokens: i32,
    /// Total tokens reported by the server.
    pub total_tokens: i32,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletionChoice {
pub message: Message,
pub finish_reason: String,
pub index: i32,
}
/// Request body for `POST /v1/moderations`.
#[derive(Serialize, Debug)]
pub struct ModerationRequest {
    /// Text submitted for moderation.
    pub input: String,
    /// Moderation model to use; left out of the JSON body when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
}
/// Body of a successful (2xx) reply to `POST /v1/moderations`.
#[derive(Deserialize, Debug)]
pub struct ModerationResponse {
    /// Identifier assigned by the server.
    pub id: String,
    /// Model name reported by the server.
    pub model: String,
    /// Verdicts returned by the server (presumably one per input — confirm).
    pub results: Vec<ModerationResult>,
}
/// One entry of [`ModerationResponse::results`].
#[derive(Deserialize, Debug)]
pub struct ModerationResult {
    /// Per-category boolean flags, keyed by the category names the server sends.
    pub categories: HashMap<String, bool>,
    /// Per-category numeric scores, keyed by the category names the server sends.
    pub category_scores: HashMap<String, f64>,
    /// Overall `flagged` value reported by the server.
    pub flagged: bool,
}
/// Async client for the OpenAI chat-completion and moderation endpoints.
///
/// Build one with [`OpenAiApiClient::new`] and reuse it for many requests.
pub struct OpenAiApiClient {
    /// Credentials attached to every outgoing request.
    configuration: Configuration,
    /// HTTP client shared by all requests made through this instance.
    client: Client,
}
impl OpenAiApiClient {
pub fn new(configuration: Configuration) -> Self {
OpenAiApiClient {
configuration,
client: Client::new(),
}
}
pub async fn create_chat_completion(
&self,
chat_completion: ChatCompletion,
) -> Result<ChatCompletionResponse> {
let url = format!("{}/v1/chat/completions", BASE_URL);
let response = self
.client
.post(&url)
.header("Content-Type", "application/json")
.header(
"Authorization",
format!("Bearer {}", self.configuration.api_key),
)
.header(
"OpenAI-Organization",
self.configuration
.organization_id
.as_ref()
.unwrap_or(&"".to_string()),
)
.json(&chat_completion)
.send()
.await?;
// Check for server errors
if !response.status().is_success() {
return Err(anyhow!(
"Server returned an error (status code: {}): {}",
response.status(),
response.text().await?
));
}
let chat_completion_response: ChatCompletionResponse = response.json().await?;
Ok(chat_completion_response)
}
pub async fn create_moderation(
&self,
moderation_request: ModerationRequest,
) -> Result<ModerationResponse> {
let url = format!("{}/v1/moderations", BASE_URL);
let response = self
.client
.post(&url)
.header("Content-Type", "application/json")
.header(
"Authorization",
format!("Bearer {}", self.configuration.api_key),
)
.header(
"OpenAI-Organization",
self.configuration
.organization_id
.as_ref()
.unwrap_or(&"".to_string()),
)
.json(&moderation_request)
.send()
.await?;
// Check for server errors
if !response.status().is_success() {
return Err(anyhow!(
"Server returned an error (status code: {}): {}",
response.status(),
response.text().await?
));
}
let moderation_response: ModerationResponse = response.json().await?;
Ok(moderation_response)
}
}