文心雕龙 ERNIE-Bot 4.0 模型流和非流 API 调用(SpringBoot+OkHttp3+SSE+WebSocket)
博客地址:https://www.yuque.com/autunomy/emwi09
摘要
本文的主要内容是针对于文心一言的ERNIE-Bot 4.0模型的API调用,使用到的技术有JDK1.8 , OkHttp3 , WebSocket , SSE , SpringBoot。API有流式和非流式两种,这里都进行了详细的代码编写,并且针对于流式API给出了两种协议的写法并且是前后端交互式的写法,看完本文后几乎所有的大模型的流式、非流式API都可以轻松调用了。
准备工作
第一步申请资格:千帆服务体验申请
第二步打开控制台:百度智能云千帆大模型平台
点击模型广场并选择自己需要模型,打开API文档
打开文档之后会有两个部分,一个是API在线调试,一个是AccessToken的获取,那接下来我们先获取AccessToken
接下来就可以看到获取API Key的方式
点击创建应用来创建API Key,应用的名字可以随便写
最后就可以获取一个API Key 和 Secret Key,这两个就可以帮助我们获取AccessToken,接下来返回之前的AccessToken的获取页面,里面就会有详细的请求入参和响应,可以根据官方的文档获取AccessToken
AccessToken的作用就是一个通行证,每次我们与文心一言对话的时候都需要带着这个通行证才能访问,接下来我用java代码使用springboot+okHttp3的方式来演示AccessToken的获取方式
使用java代码获取AccessToken
第一步肯定是引入依赖
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp</artifactId>
<version>4.9.3</version>
</dependency>
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>fastjson</artifactId>
<version>1.2.79</version>
</dependency>
第二步是配置okHttp
import okhttp3.ConnectionPool;
import okhttp3.OkHttpClient;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import javax.net.ssl.*;
import java.security.*;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.concurrent.TimeUnit;
@Configuration
public class OkHttpConfiguration {
@Value("${ok.http.connect-timeout}")
private Integer connectTimeout;
@Value("${ok.http.read-timeout}")
private Integer readTimeout;
@Value("${ok.http.write-timeout}")
private Integer writeTimeout;
@Value("${ok.http.max-idle-connections}")
private Integer maxIdleConnections;
@Value("${ok.http.keep-alive-duration}")
private Long keepAliveDuration;
@Bean
public OkHttpClient okHttpClient() {
return new OkHttpClient.Builder()
.sslSocketFactory(sslSocketFactory(), x509TrustManager())
// 是否开启缓存
.retryOnConnectionFailure(false)
.connectionPool(pool())
.connectTimeout(connectTimeout, TimeUnit.SECONDS)
.readTimeout(readTimeout, TimeUnit.SECONDS)
.writeTimeout(writeTimeout,TimeUnit.SECONDS)
.hostnameVerifier((hostname, session) -> true)
// 设置代理
// .proxy(new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 8888)))
// 拦截器
// .addInterceptor()
.build();
}
@Bean
public X509TrustManager x509TrustManager() {
return new X509TrustManager() {
@Override
public void checkClientTrusted(X509Certificate[] chain, String authType)
throws CertificateException {
}
@Override
public void checkServerTrusted(X509Certificate[] chain, String authType)
throws CertificateException {
}
@Override
public X509Certificate[] getAcceptedIssuers() {
return new X509Certificate[0];
}
};
}
@Bean
public SSLSocketFactory sslSocketFactory() {
try {
// 信任任何链接
SSLContext sslContext = SSLContext.getInstance("TLS");
sslContext.init(null, new TrustManager[]{x509TrustManager()}, new SecureRandom());
return sslContext.getSocketFactory();
} catch (NoSuchAlgorithmException | KeyManagementException e) {
e.printStackTrace();
}
return null;
}
@Bean
public ConnectionPool pool() {
return new ConnectionPool(maxIdleConnections, keepAliveDuration, TimeUnit.SECONDS);
}
}
同时需要在application.properties配置文件中进行配置
# okhttp3配置
ok.http.connect-timeout=30
ok.http.read-timeout=30
ok.http.write-timeout=30
# 连接池中整体的空闲连接的最大数量
ok.http.max-idle-connections=200
# 连接空闲时间最多为 300 秒
ok.http.keep-alive-duration=300
第三步是配置一些全局变量
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import okhttp3.MediaType;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
import java.io.IOException;
import java.util.Date;
/**
* @author hty
* @date 2023-11-18 9:42
* @email 1156388927@qq.com
* @description
*/
@Configuration
@Slf4j
public class WenXinConfig {
@Value("${wenxin.apiKey}")
public String API_KEY;
@Value("${wenxin.secretKey}")
public String SECRET_KEY;
@Value("${wenxin.accessTokenUrl}")
public String ACCESS_TOKEN_URL;
@Value("${wenxin.ERNIE-Bot4.0URL}")
public String ERNIE_Bot_4_0_URL;
//过期时间为30天
public String ACCESS_TOKEN = null;
public String REFRESH_TOKEN = null;
public Date CREATE_TIME = null;//accessToken创建时间
public Date EXPIRATION_TIME = null;//accessToken到期时间
/**
* 获取accessToken
* @return true表示成功 false表示失败
*/
public synchronized String flushAccessToken(){
//判断当前AccessToken是否为空且判断是否过期
if(ACCESS_TOKEN != null && EXPIRATION_TIME.getTime() > CREATE_TIME.getTime()) return ACCESS_TOKEN;
//构造请求体 包含请求参数和请求头等信息
RequestBody body = RequestBody.create(MediaType.parse("application/json"),"");
Request request = new Request.Builder()
.url(ACCESS_TOKEN_URL+"?client_id="+API_KEY+"&client_secret="+SECRET_KEY+"&grant_type=client_credentials")
.method("POST", body)
.addHeader("Content-Type", "application/json")
.addHeader("Accept", "application/json")
.build();
OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().build();
String response = null;
try {
//请求
response = HTTP_CLIENT.newCall(request).execute().body().string();
} catch (IOException e) {
log.error("ACCESS_TOKEN获取失败");
return null;
}
//刷新令牌以及更新令牌创建时间和过期时间
JSONObject jsonObject = JSON.parseObject(response);
ACCESS_TOKEN = jsonObject.getString("access_token");
REFRESH_TOKEN = jsonObject.getString("refresh_token");
CREATE_TIME = new Date();
EXPIRATION_TIME = new Date(Long.parseLong(jsonObject.getString("expires_in")) + CREATE_TIME.getTime());
return ACCESS_TOKEN;
}
}
可以看到代码中有@Value注解,表示需要注入属性,所以继续在application.properties配置文件中进行配置
wenxin.apiKey=你的apiKey
wenxin.secretKey=你的secretKey
#获取AccessToken的url地址
wenxin.accessTokenUrl=https://aip.baidubce.com/oauth/2.0/token
#文心ERNIE-Bot4.0模型访问地址
wenxin.ERNIE-Bot4.0URL=https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro
非流式回答
至此配置完成,接下来就是编写代码来进行访问,我使用的是springboot框架,所以就在controller中编写一个方法来实现。
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.hty.config.WenXinConfig;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RestController;
import javax.annotation.Resource;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@RestController
@Slf4j
public class TestController {
@Resource
private WenXinConfig wenXinConfig;
//历史对话,需要按照user,assistant
List<Map<String,String>> messages = new ArrayList<>();
/**
* 非流式问答
* @param question 用户的问题
* @return
* @throws IOException
*/
@PostMapping("/test1")
public String test1(String question) throws IOException {
if(question == null || question.equals("")){
return "请输入问题";
}
String responseJson = null;
//先获取令牌然后才能访问api
if (wenXinConfig.flushAccessToken() != null) {
HashMap<String, String> user = new HashMap<>();
user.put("role","user");
user.put("content",question);
messages.add(user);
String requestJson = constructRequestJson(1,0.95,0.8,1.0,false,messages);
RequestBody body = RequestBody.create(MediaType.parse("application/json"), requestJson);
Request request = new Request.Builder()
.url(wenXinConfig.ERNIE_Bot_4_0_URL + "?access_token=" + wenXinConfig.flushAccessToken())
.method("POST", body)
.addHeader("Content-Type", "application/json")
.build();
OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().build();
try {
responseJson = HTTP_CLIENT.newCall(request).execute().body().string();
//将回复的内容转为一个JSONObject
JSONObject responseObject = JSON.parseObject(responseJson);
//将回复的内容添加到消息中
HashMap<String, String> assistant = new HashMap<>();
assistant.put("role","assistant");
assistant.put("content",responseObject.getString("result"));
messages.add(assistant);
} catch (IOException e) {
log.error("网络有问题");
return "网络有问题,请稍后重试";
}
}
return responseJson;
}
/**
* 构造请求的请求参数
* @param userId
* @param temperature
* @param topP
* @param penaltyScore
* @param messages
* @return
*/
public String constructRequestJson(Integer userId,
Double temperature,
Double topP,
Double penaltyScore,
boolean stream,
List<Map<String, String>> messages) {
Map<String,Object> request = new HashMap<>();
request.put("user_id",userId.toString());
request.put("temperature",temperature);
request.put("top_p",topP);
request.put("penalty_score",penaltyScore);
request.put("stream",stream);
request.put("messages",messages);
System.out.println(JSON.toJSONString(request));
return JSON.toJSONString(request);
}
}
使用PostMan调用接口即可获取到回复
流式回答-输出到控制台
流式回答就是在原来非流式回答的基础上将stream这个参数从false改为true即可,但是还要注意的一个事情就是响应中的内容是会有一点变化的,会在原来响应的内容的前面加上data
这个字符,需要在转jsonObject的时候注意
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.hty.config.WenXinConfig;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RestController;
import javax.annotation.Resource;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@RestController
@Slf4j
public class TestController {
@Resource
private WenXinConfig wenXinConfig;
//历史对话,需要按照user,assistant
List<Map<String,String>> messages = new ArrayList<>();
/**
* 流式回答
* @return
*/
@PostMapping("/test2")
public String test2(String question){
OkHttpClient client = new OkHttpClient();
HashMap<String, String> user = new HashMap<>();
user.put("role","user");
user.put("content",question);
messages.add(user);
String requestJson = constructRequestJson(1,0.95,0.8,1.0,true,messages);
RequestBody body = RequestBody.create(MediaType.parse("application/json"), requestJson);
Request request = new Request.Builder()
.url(wenXinConfig.ERNIE_Bot_4_0_URL + "?access_token=" + wenXinConfig.flushAccessToken())
.method("POST", body)
.addHeader("Content-Type", "application/json")
.build();
StringBuilder answer = new StringBuilder();
// 发起异步请求
try {
Response response = client.newCall(request).execute();
// 检查响应是否成功
if (response.isSuccessful()) {
// 获取响应流
try (ResponseBody responseBody = response.body()) {
if (responseBody != null) {
InputStream inputStream = responseBody.byteStream();
// 以流的方式处理响应内容,输出到控制台
byte[] buffer = new byte[1024];
int bytesRead;
while ((bytesRead = inputStream.read(buffer)) != -1) {
// 在控制台输出每个数据块
System.out.write(buffer, 0, bytesRead);
//将结果汇总起来
answer.append(new String(buffer, 0, bytesRead));
}
}
}
} else {
System.out.println("Unexpected code " + response);
}
} catch (IOException e) {
log.error("流式请求出错");
throw new RuntimeException(e);
}
//将回复的内容添加到消息中
HashMap<String, String> assistant = new HashMap<>();
assistant.put("role","assistant");
assistant.put("content","");
//取出我们需要的内容,也就是result部分
String[] answerArray = answer.toString().split("data: ");
for (int i=1;i<answerArray.length;++i) {
answerArray[i] = answerArray[i].substring(0,answerArray[i].length() - 2);
assistant.put("content",assistant.get("content") + JSON.parseObject(answerArray[i]).get("result"));
}
messages.add(assistant);
return assistant.get("content");
}
/**
* 构造请求的请求参数
* @param userId
* @param temperature
* @param topP
* @param penaltyScore
* @param messages
* @return
*/
public String constructRequestJson(Integer userId,
Double temperature,
Double topP,
Double penaltyScore,
boolean stream,
List<Map<String, String>> messages) {
Map<String,Object> request = new HashMap<>();
request.put("user_id",userId.toString());
request.put("temperature",temperature);
request.put("top_p",topP);
request.put("penalty_score",penaltyScore);
request.put("stream",stream);
request.put("messages",messages);
System.out.println(JSON.toJSONString(request));
return JSON.toJSONString(request);
}
}
流式回答-输出到前端
这个地方分两种方式,一种是使用SSE(Server-Sent Events),一种是使用websocket
SSE方式
引入依赖
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp</artifactId>
<version>4.9.3</version>
</dependency>
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp-sse</artifactId>
<version>4.9.3</version>
</dependency>
配置跨域
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.CorsRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
@Configuration
public class CorsConfig {
@Bean
public WebMvcConfigurer corsConfigurer() {
return new WebMvcConfigurer() {
@Override
public void addCorsMappings(CorsRegistry registry) {
registry.addMapping("/**")
.allowedOrigins("*");//允许域名访问,如果*,代表所有域名
}
};
}
}
前端代码
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>SSE Example</title>
</head>
<body>
<div id="sse-data" v-html="message"></div>
<script src="https://code.jquery.com/jquery-3.6.4.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/vue@2"></script>
<script>
const app = new Vue({
el: '#sse-data',
data: {
message: '',
},
created() {
this.createSseConnect();
},
methods: {
// 建立连接
createSseConnect(){
let clientId = 1;
if(window.EventSource){
const eventSource = new EventSource('http://localhost:8080/sse/connect?clientId=' + clientId);
console.log(eventSource)
eventSource.onmessage = (event) =>{
console.log("onmessage: "+event.data)
this.message = this.message + event.data
};
eventSource.onopen = (event) =>{
console.log("onopen:"+event)
};
eventSource.onerror = (event) =>{
console.log("onerror:"+event)
};
eventSource.close = (event) =>{
console.log("close :"+event)
};
}else{
console.log("你的浏览器不支持SSE~")
}
console.log(" 测试 打印")
},
},
});
</script>
</body>
</html
后端代码部分的写法和WebSocket非常的像,都是向创建连接,之后再发送消息
/**
* 这个用来保存用户与服务器之间的连接信息
*/
private static final Map<Long, SseEmitter> sseEmitterMap = new ConcurrentHashMap<>();
//创建连接
@GetMapping(value = "/sse/connect", produces="text/event-stream;charset=UTF-8")
public SseEmitter sseConnect(Long clientId){
//已经连接过,直接返回连接
if (sseEmitterMap.containsKey(clientId)) {
return sseEmitterMap.get(clientId);
}
try {
// 设置超时时间,0表示不过期。默认30秒
SseEmitter sseEmitter = new SseEmitter(30 * 1000L);
// 注册回调
sseEmitter.onCompletion(completionCallBack(clientId));
sseEmitter.onTimeout(timeoutCallBack(clientId));
sseEmitterMap.put(clientId, sseEmitter);
log.info("创建sse连接完成,当前客户端:{}", clientId);
return sseEmitter;
} catch (Exception e) {
log.info("创建sse连接异常,当前客户端:{}", clientId);
}
return null;
}
//发送消息,这里采用异步的方式来进行发送
/**
* 用来异步发送消息
*/
private final ExecutorService executorService = Executors.newCachedThreadPool();
/**
* SSE方式向前端发送消息
* @param clientId
* @param question
*/
@PostMapping(value = "/sse/chat")
public void streamOutputToPage(Long clientId,String question){
//异步发送消息
executorService.execute(() -> {
SseEmitter sseEmitter = sseEmitterMap.get(clientId);
if(sseEmitter == null){
sseEmitter = sseChat(clientId);
}
OkHttpClient client = new OkHttpClient();
String requestJson = wenxinUtils.constructRequestJson(1,0.95,1.0,true,messages);
RequestBody body = RequestBody.create(MediaType.parse("application/json"), requestJson);
Request request = new Request.Builder()
.url(wenXinConfig.ERNIE_Bot_4_0_URL + "?access_token=" + wenXinConfig.flushAccessToken())
.method("POST", body)
.addHeader("Content-Type", "application/json")
.build();
//将回复的内容添加到消息中
HashMap<String, String> assistant = new HashMap<>();
assistant.put("role","assistant");
assistant.put("content","");
// 发起异步请求
try {
Response response = client.newCall(request).execute();
// 检查响应是否成功
if (response.isSuccessful()) {
// 获取响应流
try (ResponseBody responseBody = response.body()) {
if (responseBody != null) {
InputStream inputStream = responseBody.byteStream();
// 以流的方式处理响应内容,输出到控制台 这里的数组大小一定不能太小,否则会导致接收中文字符的时候产生乱码
byte[] buffer = new byte[2048];
int bytesRead;
StringBuilder temp = new StringBuilder();
while ((bytesRead = inputStream.read(buffer)) != -1) {
//TODO:这部分不需要使用\n\n来进行分割了,只需要将缓冲区开的尽可能大即可
//消息分割采用标识符 \n\n 来分割 并且需要从后向前找\n\n,因为每条消息分割点的最后才是\n\n
temp.append(new String(buffer, 0, bytesRead));
String result = "";
if(temp.lastIndexOf("\n\n") != -1){
//从6开始 因为有 data: 这个前缀 占了6个字符所以 0 + 6 = 6
JSONObject jsonObject = JSON.parseObject(temp.substring(6, temp.lastIndexOf("\n\n")));
temp = new StringBuilder(temp.substring(temp.lastIndexOf("\n\n") + 2));
if(jsonObject != null && jsonObject.getString("result") != null){
result = jsonObject.getString("result");
}
}
if(!result.equals("")){
//SSE协议默认是以两个\n换行符为结束标志 需要在进行一次转义才能成功发送给前端
sseEmitter.send(SseEmitter.event().data(result.replace("\n","\\n")));
//将结果汇总起来
assistant.put("content",assistant.get("content") + result);
}
}
messages.add(assistant);
}
}
} else {
System.out.println("Unexpected code " + response);
}
} catch (IOException e) {
log.error("流式请求出错,断开与{}的连接",clientId);
//移除当前的连接
sseEmitterMap.remove(clientId);
//移除本次对话的内容
messages.remove(user);
}
});
}
至此代码编写完毕,调用的方式也是非常的简单,首先将前端页面打开就可以在控制台的network部分看到连接了,之后我们想要发送问题的时候直接使用PostMan之类的API调试工具发送请求即可,具体操作看下图
question部分就是你的问题,clientId就是刚才在前端页面的clientId,表示服务器推送问题答案的服务器,最后在页面中即可看到显示出来的信息。
但是页面中的内容出现的非常突兀,没有一个字一个字的弹出,这个就不是后端的技术问题了,是前端的CSS样式,自行搜索即可。
问题1:如何向前端使用流式返回?
利用到的是SSE中的SseEmitter对象来向前端发送消息,其实流式只是一个称呼而已,最终的实际情况也不过是将消息拆分开来进行发送,这样做的好处就是让用户使用过程中更加的舒服,防止用户没有目的的等待。
问题2:如何分隔两条流式消息?
我们通过对返回数据的观察发现,所有的data之间都使用两个换行符来进行分割的,那我们就可以利用\n\n
来进行消息的分割'
问题3:如何保证\n\n一定是两条消息之间的分割而不是内容中的\n\n?
让缓冲区尽量大,保证每次都能将一条data数据读取完毕,这样就能保证\n\n一定是两条消息之间的分隔符,目前也只有这一种方法可行,根据我的测试来看,只要设置1024个字节以上的缓冲区,基本不会存在一次性读取不完数据的情况
问题3补充
由于我们缓冲区设置的足够大,所以我们甚至不需要再对消息进行分割,我们每次缓冲区读取到的数据就一定是独立的一条消息
问题4:发送给前端数据时数据产生截断问题?
SSE协议默认是以两个\n换行符为结束标志,所以在数据中如果有\n\n
存在的话就会导致后面的数据被截断,无法收到,解决办法就是再对\n进行一次转义即可
WebSocket方式
待补充
注意事项
本文并没有针对于token进行统计,需要用户自行进行代码来实现token的统计。
上述代码并没有进行抽取合并,大家可以自行抽取,后面我也会将我抽取的代码放在Github上供大家参考。