diff --git a/.gitignore b/.gitignore index ccec7880..8b3ecc23 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,10 @@ .idea dist .env +config.yaml bin glide tmp coverage.txt +precommit.txt +pkg/providers/openai/openai_test.go \ No newline at end of file diff --git a/go.mod b/go.mod index dc213276..bc9bad0f 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,8 @@ go 1.21.5 require ( github.com/cloudwego/hertz v0.7.3 - github.com/go-playground/validator/v10 v10.16.0 github.com/hertz-contrib/logger/zap v1.1.0 + github.com/joho/godotenv v1.5.1 github.com/spf13/cobra v1.8.0 github.com/stretchr/testify v1.8.2 go.uber.org/goleak v1.3.0 @@ -22,27 +22,21 @@ require ( github.com/cloudwego/netpoll v0.5.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/fsnotify/fsnotify v1.5.4 // indirect - github.com/gabriel-vasile/mimetype v1.4.3 // indirect - github.com/go-playground/locales v0.14.1 // indirect - github.com/go-playground/universal-translator v0.18.1 // indirect github.com/golang/protobuf v1.5.0 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/henrylee2cn/ameda v1.4.10 // indirect github.com/henrylee2cn/goutil v0.0.0-20210127050712-89660552f6f8 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/klauspost/cpuid/v2 v2.0.9 // indirect - github.com/leodido/go-urn v1.2.4 // indirect github.com/nyaruka/phonenumbers v1.0.55 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/stretchr/testify v1.8.2 // indirect github.com/tidwall/gjson v1.14.4 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect - golang.org/x/crypto v0.16.0 // indirect - golang.org/x/net v0.19.0 // indirect - golang.org/x/sys v0.15.0 // indirect - golang.org/x/text v0.14.0 // indirect + golang.org/x/sys v0.13.0 // indirect google.golang.org/protobuf v1.27.1 // indirect ) diff --git a/go.sum b/go.sum index 328d351e..816a279f 100644 --- a/go.sum +++ b/go.sum @@ -20,16 +20,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI= github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= -github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= -github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= -github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= -github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= -github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= -github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= -github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= -github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= -github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE= -github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= @@ -47,16 +37,12 @@ github.com/hertz-contrib/logger/zap v1.1.0 h1:4efINiIDJrXEtAFeEdDJvc3Hye0VFxp+0X github.com/hertz-contrib/logger/zap v1.1.0/go.mod h1:D/rJJgsYn+SGaHVfVqWS3vHTbbc7ODAlJO+6smWgTeE= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= -github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= github.com/nyaruka/phonenumbers v1.0.55 h1:bj0nTO88Y68KeUQ/n3Lo2KgK7lM1hF7L9NFuwcCl3yg= github.com/nyaruka/phonenumbers v1.0.55/go.mod h1:sDaTZ/KPX5f8qyV9qN+hIm+4ZBARJrupC6LuhshJq1U= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -100,20 +86,14 @@ golang.org/x/arch v0.0.0-20201008161808-52c3e6f60cff/go.mod h1:flIaEI6LNU6xOCD5P golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY= -golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= -golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= diff --git a/pkg/buildAPIRequest.go b/pkg/buildAPIRequest.go deleted file mode 100644 index e8f280e8..00000000 --- a/pkg/buildAPIRequest.go +++ /dev/null @@ -1,58 +0,0 @@ -// this file contains the BuildAPIRequest function which takes in the provider name, params map, and mode and returns the providerConfig map and error -// The providerConfig map can be used to build the API request to the provider -package pkg - -import ( - "errors" - "fmt" - - "glide/pkg/providers" - "glide/pkg/providers/openai" - - "github.com/go-playground/validator/v10" -) - -type ProviderConfigs = pkg.ProviderConfigs - -// Initialize configList - -var configList = map[string]interface{}{ - "openai": openai.OpenAIConfig, -} - -// Create a new validator instance -var validate *validator.Validate = validator.New() - -func BuildAPIRequest(provider string, params map[string]string, mode string) (interface{}, error) { - // provider is the name of the provider, e.g. "openai", params is the map of parameters from the client, - // mode is the mode of the provider, e.g. "chat", configList is the list of provider configurations - var providerConfig map[string]interface{} - - if config, ok := configList[provider].(ProviderConfigs); ok { - if modeConfig, ok := config[mode].(map[string]interface{}); ok { - providerConfig = modeConfig - } - } - - // If the provider is not supported, return an error - if providerConfig == nil { - return nil, errors.New("unsupported provider") - } - - // Build the providerConfig map by iterating over the keys in the providerConfig map and checking if the key exists in the params map - - for key := range providerConfig { - if value, exists := params[key]; exists { - providerConfig[key] = value - } - } - - // Validate the providerConfig map using the validator package - err := validate.Struct(providerConfig) - if err != nil { - // Handle validation error - return nil, fmt.Errorf("validation error: %v", err) - } - // If everything is fine, return the providerConfig and nil error - return providerConfig, nil -} diff --git a/pkg/providers/openai/api.go b/pkg/providers/openai/api.go deleted file mode 100644 index 26f9e6e9..00000000 --- a/pkg/providers/openai/api.go +++ /dev/null @@ -1,29 +0,0 @@ -package openai - -import ( - "fmt" - "net/http" -) - -// provides the base URL and headers for the OpenAI API -type ProviderAPIConfig struct { - BaseURL string - Headers func(string) http.Header - Complete string - Chat string - Embed string -} - -func APIConfig(_ string) *ProviderAPIConfig { - return &ProviderAPIConfig{ - BaseURL: "https://api.openai.com/v1", - Headers: func(APIKey string) http.Header { - headers := make(http.Header) - headers.Set("Authorization", fmt.Sprintf("Bearer %s", APIKey)) - return headers - }, - Complete: "/completions", - Chat: "/chat/completions", - Embed: "/embeddings", - } -} diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 61baf4ca..3471a9c9 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -1,53 +1,234 @@ package openai -type ProviderConfig struct { - Model string `json:"model" validate:"required,lowercase"` - Messages string `json:"messages" validate:"required"` // does this need to be updated to []string? - MaxTokens int `json:"max_tokens" validate:"omitempty,gte=0"` - Temperature int `json:"temperature" validate:"omitempty,gte=0,lte=2"` - TopP int `json:"top_p" validate:"omitempty,gte=0,lte=1"` - N int `json:"n" validate:"omitempty,gte=1"` - Stream bool `json:"stream" validate:"omitempty, boolean"` - Stop interface{} `json:"stop"` - PresencePenalty int `json:"presence_penalty" validate:"omitempty,gte=-2,lte=2"` - FrequencyPenalty int `json:"frequency_penalty" validate:"omitempty,gte=-2,lte=2"` - LogitBias *map[int]float64 `json:"logit_bias" validate:"omitempty"` - User interface{} `json:"user"` - Seed interface{} `json:"seed" validate:"omitempty,gte=0"` - Tools []string `json:"tools"` - ToolChoice interface{} `json:"tool_choice"` - ResponseFormat interface{} `json:"response_format"` -} - -var defaultMessage = `[ - { - "role": "system", - "content": "You are a helpful assistant." - }, - { - "role": "user", - "content": "Hello!" - } - ]` - -// Provide the request body for OpenAI's ChatCompletion API -func ChatDefaultConfig() ProviderConfig { - return ProviderConfig{ - Model: "gpt-3.5-turbo", - Messages: defaultMessage, - MaxTokens: 100, - Temperature: 1, +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "reflect" + "strconv" + "strings" + + "glide/pkg/providers" + + "glide/pkg/telemetry" + + "go.uber.org/zap" +) + +const ( + defaultChatModel = "gpt-3.5-turbo" + defaultEndpoint = "/chat/completions" +) + +// Client is a client for the OpenAI API. +type ProviderClient struct { + BaseURL string `json:"baseURL"` + HTTPClient *http.Client `json:"httpClient"` + Telemetry *telemetry.Telemetry `json:"telemetry"` +} + +// ChatRequest is a request to complete a chat completion.. +type ChatRequest struct { + Model string `json:"model"` + Messages []map[string]string `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` + StopWords []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + FrequencyPenalty int `json:"frequency_penalty,omitempty"` + PresencePenalty int `json:"presence_penalty,omitempty"` + LogitBias *map[int]float64 `json:"logit_bias,omitempty"` + User interface{} `json:"user,omitempty"` + Seed interface{} `json:"seed,omitempty"` + Tools []string `json:"tools,omitempty"` + ToolChoice interface{} `json:"tool_choice,omitempty"` + ResponseFormat interface{} `json:"response_format,omitempty"` + + // StreamingFunc is a function to be called for each chunk of a streaming response. + // Return an error to stop streaming early. + StreamingFunc func(ctx context.Context, chunk []byte) error `json:"-"` +} + +// ChatMessage is a message in a chat request. +type ChatMessage struct { + // The role of the author of this message. One of system, user, or assistant. + Role string `json:"role"` + // The content of the message. + Content string `json:"content"` + // The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores, + // with a maximum length of 64 characters. + Name string `json:"name,omitempty"` +} + +// ChatChoice is a choice in a chat response. +type ChatChoice struct { + Index int `json:"index"` + Message ChatMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +// ChatResponse is a response to a chat request. +type ChatResponse struct { + ID string `json:"id,omitempty"` + Created float64 `json:"created,omitempty"` + Choices []*ChatChoice `json:"choices,omitempty"` + Model string `json:"model,omitempty"` + Object string `json:"object,omitempty"` + Usage struct { + CompletionTokens float64 `json:"completion_tokens,omitempty"` + PromptTokens float64 `json:"prompt_tokens,omitempty"` + TotalTokens float64 `json:"total_tokens,omitempty"` + } `json:"usage,omitempty"` +} + +// Chat sends a chat request to the specified OpenAI model. +// +// Parameters: +// - payload: The user payload for the chat request. +// Returns: +// - *ChatResponse: a pointer to a ChatResponse +// - error: An error if the request failed. +func (c *ProviderClient) Chat(u *providers.UnifiedAPIData) (*ChatResponse, error) { + // Create a new chat request + c.Telemetry.Logger.Info("creating new chat request") + + chatRequest := c.CreateChatRequest(u) + + c.Telemetry.Logger.Info("chat request created") + + // Send the chat request + + resp, err := c.CreateChatResponse(context.Background(), chatRequest, u) + + return resp, err +} + +func (c *ProviderClient) CreateChatRequest(u *providers.UnifiedAPIData) *ChatRequest { + c.Telemetry.Logger.Info("creating chatRequest from payload") + + var messages []map[string]string + + // Add items from messageHistory first + messages = append(messages, u.MessageHistory...) + + // Add msg variable last + messages = append(messages, u.Message) + + // Iterate through unifiedData.Params and add them to the request, otherwise leave the default value + defaultParams := u.Params + + chatRequest := &ChatRequest{ + Model: u.Model, + Messages: messages, + Temperature: 0.8, TopP: 1, + MaxTokens: 100, N: 1, + StopWords: []string{}, Stream: false, - Stop: nil, - PresencePenalty: 0, FrequencyPenalty: 0, + PresencePenalty: 0, LogitBias: nil, User: nil, Seed: nil, - Tools: nil, + Tools: []string{}, ToolChoice: nil, ResponseFormat: nil, } + + chatRequestValue := reflect.ValueOf(chatRequest).Elem() + chatRequestType := chatRequestValue.Type() + + for i := 0; i < chatRequestValue.NumField(); i++ { + jsonTags := strings.Split(chatRequestType.Field(i).Tag.Get("json"), ",") + jsonTag := jsonTags[0] + + if value, ok := defaultParams[jsonTag]; ok { + fieldValue := chatRequestValue.Field(i) + fieldValue.Set(reflect.ValueOf(value)) + } + } + + // c.Telemetry.Logger.Info("chatRequest created", zap.Any("chatRequest body", chatRequest)) + + return chatRequest +} + +// CreateChatResponse creates chat Response. +func (c *ProviderClient) CreateChatResponse(ctx context.Context, r *ChatRequest, u *providers.UnifiedAPIData) (*ChatResponse, error) { + _ = ctx // keep this for future use + + resp, err := c.createChatHTTP(r, u) + if err != nil { + return nil, err + } + + if len(resp.Choices) == 0 { + return nil, ErrEmptyResponse + } + + return resp, nil +} + +func (c *ProviderClient) createChatHTTP(payload *ChatRequest, u *providers.UnifiedAPIData) (*ChatResponse, error) { + c.Telemetry.Logger.Info("running createChatHttp") + + if payload.StreamingFunc != nil { + payload.Stream = true + } + // Build request payload + payloadBytes, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + // Build request + if defaultBaseURL == "" { + c.Telemetry.Logger.Error("defaultBaseURL not set") + return nil, errors.New("baseURL not set") + } + + reqBody := bytes.NewBuffer(payloadBytes) + req, err := http.NewRequest(http.MethodPost, buildURL(defaultEndpoint), reqBody) + if err != nil { + c.Telemetry.Logger.Error(err.Error()) + return nil, err + } + + req.Header.Set("Authorization", "Bearer "+u.APIKey) + req.Header.Set("Content-Type", "application/json") + + resp, err := providers.HTTPClient.Do(req) + if err != nil { + c.Telemetry.Logger.Error(err.Error()) + return nil, err + } + defer resp.Body.Close() + + c.Telemetry.Logger.Info("Response Code: ", zap.String("response_code", strconv.Itoa(resp.StatusCode))) + + if resp.StatusCode != http.StatusOK { + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + c.Telemetry.Logger.Error(err.Error()) + } + + c.Telemetry.Logger.Warn("Response Body: ", zap.String("response_body", string(bodyBytes))) + } + + // Parse response + var response ChatResponse + + return &response, json.NewDecoder(resp.Body).Decode(&response) +} + +func buildURL(suffix string) string { + // open ai implement: + return fmt.Sprintf("%s%s", defaultBaseURL, suffix) } diff --git a/pkg/providers/openai/index.go b/pkg/providers/openai/index.go deleted file mode 100644 index 7dffb4e8..00000000 --- a/pkg/providers/openai/index.go +++ /dev/null @@ -1,11 +0,0 @@ -package openai - -import ( - "glide/pkg/providers" -) - -// TODO: this needs to be imported into buildAPIRequest.go -var OpenAIConfig = pkg.ProviderConfigs{ - "api": APIConfig, - "chat": ChatDefaultConfig, -} diff --git a/pkg/providers/openai/openaiclient.go b/pkg/providers/openai/openaiclient.go new file mode 100644 index 00000000..3b9060ba --- /dev/null +++ b/pkg/providers/openai/openaiclient.go @@ -0,0 +1,50 @@ +// TODO: Explore resource pooling +// TODO: Optimize Type use +// TODO: Explore Hertz TLS & resource pooling +// OpenAI package provide a set of functions to interact with the OpenAI API. +package openai + +import ( + "errors" + + "glide/pkg/providers" + + "glide/pkg/telemetry" +) + +const ( + providerName = "openai" + defaultBaseURL = "https://api.openai.com/v1" +) + +// ErrEmptyResponse is returned when the OpenAI API returns an empty response. +var ( + ErrEmptyResponse = errors.New("empty response") +) + +// OpenAiClient creates a new client for the OpenAI API. +// +// Parameters: +// - poolName: The name of the pool to connect to. +// - modelName: The name of the model to use. +// +// Returns: +// - *Client: A pointer to the created client. +// - error: An error if the client creation failed. +func Client() (*ProviderClient, error) { + tel, err := telemetry.NewTelemetry(&telemetry.Config{LogConfig: telemetry.NewLogConfig()}) + if err != nil { + return nil, err + } + + tel.Logger.Info("init openai provider client") + + // Create a new client + c := &ProviderClient{ + BaseURL: defaultBaseURL, + HTTPClient: providers.HTTPClient, + Telemetry: tel, + } + + return c, nil +} diff --git a/pkg/providers/types.go b/pkg/providers/types.go index b27094d5..2c07565f 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -1,3 +1,58 @@ -package pkg +package providers -type ProviderConfigs map[string]interface{} +import ( + "net/http" + "time" +) + +type GatewayConfig struct { + Gateway PoolsConfig `yaml:"gateway" validate:"required"` +} +type PoolsConfig struct { + Pools []Pool `yaml:"pools" validate:"required"` +} + +type Pool struct { + Name string `yaml:"pool" validate:"required"` + Balancing string `yaml:"balancing" validate:"required"` + Providers []Provider `yaml:"providers" validate:"required"` +} + +type Provider struct { + Name string `yaml:"name" validate:"required"` + Model string `yaml:"model"` + APIKey string `yaml:"api_key" validate:"required"` + TimeoutMs int `yaml:"timeout_ms,omitempty"` + DefaultParams map[string]interface{} `yaml:"default_params,omitempty"` +} + +type ProviderVars struct { + Name string `yaml:"name"` + ChatBaseURL string `yaml:"chatBaseURL"` +} + +type RequestBody struct { + Message []struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"message"` + MessageHistory []string `json:"messageHistory"` +} + +// Variables + +var HTTPClient = &http.Client{ + Timeout: time.Second * 30, + Transport: &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 2, + }, +} + +type UnifiedAPIData struct { + Model string `json:"model"` + APIKey string `json:"api_key"` + Params map[string]interface{} `json:"params"` + Message map[string]string `json:"message"` + MessageHistory []map[string]string `json:"messageHistory"` +}